Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Run sorters in docker through hither #223

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docker_test/docker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import spikeextractors as se
import spikesorters as ss


# Build a small toy recording; dumpable=True is required so the recording can
# be serialized to a dict and re-instantiated inside the docker container.
rec, _ = se.example_datasets.toy_example(dumpable=True)

# NOTE(review): folder name says "ms4" but the sorter run below is klusta —
# presumably a leftover from an earlier mountainsort4 test; confirm/rename.
output_folder = "ms4_test_docker"

# Run klusta inside its docker image (use_docker=True routes through hither).
sorting = ss.run_klusta(rec, output_folder=output_folder, use_docker=True)

print(f"KL found #{len(sorting.get_unit_ids())} units")


# NOTE(review): the disabled blocks below reference `ssd`, which is never
# imported in this script — they would fail if re-enabled as-is.
# output_folder = "kl_test_docker"
#
# sorting_KL = ssd.run_klusta(rec, output_folder=output_folder)
#
# print(f"KL found #{len(sorting_KL.get_unit_ids())} units")
#
# rec, _ = se.example_datasets.toy_example(dumpable=True)
#
# output_folder = "sc_test_docker"
#
# sorting_SC = ssd.run_spykingcircus(rec, output_folder=output_folder)
#
# print(f"SC found #{len(sorting_SC.get_unit_ids())} units")
#
# rec, _ = se.example_datasets.toy_example(dumpable=True)
#
# output_folder = "hs_test_docker"
#
# sorting_HS = ssd.run_herdingspikes(rec, output_folder=output_folder)
#
# print(f"HS found #{len(sorting_HS.get_unit_ids())} units")
74 changes: 74 additions & 0 deletions spikesorters/docker_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import spikeextractors as se
import time
import numpy as np
from pathlib import Path

# Absolute path of the spikesorters package directory; used below to locate
# the bundled Dockerfiles under docker_images/.
ss_folder = Path(__file__).parent

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would name this hither_tools instead of docker tools

# Docker support is optional: hither2 and docker are only needed when sorters
# are executed inside containers.  If either import fails we flag container
# support as unavailable instead of raising at package import time.
try:
    import hither2 as hither
    import docker

    HAVE_DOCKER = True

    # Docker image for each supported sorter, keyed by sorter name.
    # klusta and mountainsort4 are built on the fly from Dockerfiles shipped
    # with the package; herdingspikes and spyking-circus use prebuilt images
    # that must already be present locally.
    default_docker_images = {
        "klusta": hither.DockerImageFromScript(name="klusta", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "klusta" / "Dockerfile")),
        "mountainsort4": hither.DockerImageFromScript(name="ms4", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "mountainsort4" / "Dockerfile")),
        "herdingspikes": hither.LocalDockerImage('spikeinterface/herdingspikes-si-0.12:0.3.7'),
        "spykingcircus": hither.LocalDockerImage('spikeinterface/spyking-circus-si-0.12:1.0.7')
    }

except ImportError:
    HAVE_DOCKER = False


def modify_input_folder(dump_dict, input_folder="/input"):
    """
    Rewrite the path entry of a dumped extractor dict so that it points inside
    `input_folder` (the mount point seen from within the container).

    The dict is modified in place and also returned.  Exactly one of
    'file_path', 'folder_path', or 'file_or_folder_path' must be present at
    the innermost level; nested extractors are handled via the 'kwargs' key.

    Parameters
    ----------
    dump_dict: dict
        Dumped extractor dictionary (possibly nested via 'kwargs').
    input_folder: str
        Mount point of the data inside the container (default '/input').

    Returns
    -------
    dump_dict: dict
        The dictionary with its path rewritten relative to `input_folder`.
    folder_to_mount: Path
        The local parent folder that must be bind-mounted onto `input_folder`.

    Raises
    ------
    ValueError
        If no recognized path key is found in `dump_dict`.
    """
    if "kwargs" in dump_dict:
        # Bug fix: propagate input_folder into the recursion (it was previously
        # dropped, so nested extractors always got the default "/input").
        dcopy_kwargs, folder_to_mount = modify_input_folder(dump_dict["kwargs"], input_folder)
        dump_dict["kwargs"] = dcopy_kwargs
        return dump_dict, folder_to_mount

    # All three path keys are handled identically; check them in the same
    # priority order as the original implementation.
    for key in ("file_path", "folder_path", "file_or_folder_path"):
        if key in dump_dict:
            path = Path(dump_dict[key])
            folder_to_mount = path.parent
            relative = path.relative_to(folder_to_mount)
            dump_dict[key] = f"{input_folder}/{relative}"
            return dump_dict, folder_to_mount

    raise ValueError(
        "dump_dict must contain 'file_path', 'folder_path', or 'file_or_folder_path'"
    )


def return_local_data_folder(recording, input_folder='/input'):
    """
    Return a copy of the recording's dumped dict with its 'file_path',
    'folder_path', or 'file_or_folder_path' entry rewritten to live inside
    `input_folder`, together with the local folder that must be mounted there.

    Parameters
    ----------
    recording: se.RecordingExtractor
        The recording to dump; must be dumpable.
    input_folder: str
        Mount point of the data inside the container (default '/input').

    Returns
    -------
    dump_dict: dict
        The rewritten dictionary (the recording's own dump is left untouched).
    folder_to_mount: Path
        The local folder to bind-mount onto `input_folder`.
    """
    from copy import deepcopy

    assert recording.is_dumpable, "recording must be dumpable to run in a container"

    # Deep-copy before rewriting so the recording's dump stays pristine.
    dump_copy = deepcopy(recording.dump_to_dict())
    return modify_input_folder(dump_copy, input_folder)
1 change: 1 addition & 0 deletions spikesorters/run_funtions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .run_functions import _run_sorter_local, _run_sorter_hither
149 changes: 149 additions & 0 deletions spikesorters/run_funtions/run_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from ..docker_tools import HAVE_DOCKER
from ..sorterlist import sorter_dict, sorter_full_list


if HAVE_DOCKER:
    # conditional definition of hither tools
    import time
    from pathlib import Path
    import hither2 as hither
    import spikeextractors as se
    import numpy as np
    import shutil
    from ..docker_tools import modify_input_folder, default_docker_images

    class SpikeSortingDockerHook(hither.RuntimeHook):
        # Hither runtime hook that, before the container starts, binds the
        # recording data (read-only) and the output folder (writable) into
        # the container and selects the docker image for the chosen sorter.
        def __init__(self):
            super().__init__()

        def precontainer(self, context: hither.PreContainerContext):
            # this gets run outside the container before the run, and we have a chance to mutate the kwargs,
            # add bind mounts, and set the image
            input_directory = context.kwargs['input_directory']
            output_directory = context.kwargs['output_directory']

            print("Input:", input_directory)
            print("Output:", output_directory)
            # Recording data is mounted read-only; sorter output is writable.
            context.add_bind_mount(hither.BindMount(source=input_directory,
                                                    target='/input', read_only=True))
            context.add_bind_mount(hither.BindMount(source=output_directory,
                                                    target='/output', read_only=False))
            # Image lookup is keyed by sorter name (see default_docker_images).
            context.image = default_docker_images[context.kwargs['sorter_name']]
            # From inside the container the two directories appear at the
            # fixed mount points, so rewrite the kwargs accordingly.
            context.kwargs['output_directory'] = '/output'
            context.kwargs['input_directory'] = '/input'

    @hither.function('run_sorter_docker_with_container',
                     '0.1.0',
                     image=True,
                     runtime_hooks=[SpikeSortingDockerHook()])
    def run_sorter_docker_with_container(
            recording_dict, sorter_name, input_directory, output_directory, **kwargs
    ):
        # Entry point executed INSIDE the container: rebuild the recording
        # from its dict, run the sorter locally, and save the result as npz
        # into the mounted output directory.
        recording = se.load_extractor_from_dict(recording_dict)
        # run sorter
        kwargs["output_folder"] = f"{output_directory}/working"
        t_start = time.time()
        # set output folder within the container
        sorting = _run_sorter_local(sorter_name, recording, **kwargs)
        t_stop = time.time()
        print(f'{sorter_name} run time {np.round(t_stop - t_start)}s')
        # save sorting to npz
        se.NpzSortingExtractor.write_sorting(sorting, f"{output_directory}/sorting_docker.npz")

    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
                           n_jobs=-1, joblib_backend='loky', **params):
        # Docker-backed counterpart of _run_sorter_local: dumps the recording,
        # runs the sorter in a container via hither, and loads the npz result.
        assert recording.is_dumpable, "Cannot run not dumpable recordings in docker"
        if output_folder is None:
            output_folder = sorter_name + '_output'
        output_folder = Path(output_folder).absolute()
        output_folder.mkdir(exist_ok=True, parents=True)

        with hither.Config(use_container=True, show_console=True):
            # Rewrite the recording's data path so it resolves to the /input
            # mount inside the container; input_directory is what gets mounted.
            dump_dict_container, input_directory = modify_input_folder(recording.dump_to_dict(), '/input')
            print(dump_dict_container)
            kwargs = dict(recording_dict=dump_dict_container,
                          sorter_name=sorter_name,
                          output_folder=str(output_folder),
                          delete_output_folder=False,
                          grouping_property=grouping_property, parallel=parallel,
                          verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
                          joblib_backend=joblib_backend)

            kwargs.update(params)
            kwargs.update({'input_directory': str(input_directory), 'output_directory': str(output_folder)})
            sorting_job = hither.Job(run_sorter_docker_with_container, kwargs)
            sorting_job.wait()
        # The container wrote its result to sorting_docker.npz (see above).
        sorting = se.NpzSortingExtractor(output_folder / "sorting_docker.npz")
        if delete_output_folder:
            shutil.rmtree(output_folder)
        return sorting
else:
    # Fallback with the same signature so callers get a clear failure only
    # when docker execution is actually requested.
    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
                           n_jobs=-1, joblib_backend='loky', **params):
        raise ImportError()


# generic launcher via function approach
# generic launcher via function approach
def _run_sorter_local(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
                      grouping_property=None, parallel=False, verbose=False, raise_error=True, n_jobs=-1,
                      joblib_backend='loky', **params):
    """
    Run a sorter on the local machine (internal backend of `run_sorter`).

    The sorter may be given by name or by class:

    by name:
    >>> sorting = _run_sorter_local('tridesclous', recording)

    by class:
    >>> sorting = _run_sorter_local(TridesclousSorter, recording)

    Parameters
    ----------
    sorter_name_or_class: str or SorterClass
        The sorter to retrieve default parameters from
    recording: RecordingExtractor
        The recording extractor to be spike sorted
    output_folder: str or Path
        Path to output folder
    delete_output_folder: bool
        If True, output folder is deleted (default False)
    grouping_property: str
        Splits spike sorting by 'grouping_property' (e.g. 'groups')
    parallel: bool
        If True and spike sorting is by 'grouping_property', spike sorting jobs are launched in parallel
    verbose: bool
        If True, output is verbose
    raise_error: bool
        If True, an error is raised if spike sorting fails (default). If False, the process continues and the error is
        logged in the log file.
    n_jobs: int
        Number of jobs when parallel=True (default=-1)
    joblib_backend: str
        joblib backend when parallel=True (default='loky')
    **params: keyword args
        Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)'

    Returns
    -------
    sortingextractor: SortingExtractor
        The spike sorted data
    """
    # Resolve the sorter class from either a registered name or the class itself.
    if isinstance(sorter_name_or_class, str):
        SorterClass = sorter_dict[sorter_name_or_class]
    elif sorter_name_or_class in sorter_full_list:
        SorterClass = sorter_name_or_class
    else:
        raise ValueError(f'Unknown sorter: {sorter_name_or_class}')

    sorter = SorterClass(recording=recording, output_folder=output_folder, grouping_property=grouping_property,
                         verbose=verbose, delete_output_folder=delete_output_folder)
    sorter.set_params(**params)
    sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
    sortingextractor = sorter.get_result(raise_error=raise_error)

    return sortingextractor
45 changes: 29 additions & 16 deletions spikesorters/sorterlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from .yass import YassSorter
from .combinato import CombinatoSorter


from .run_funtions import _run_sorter_local, _run_sorter_hither
from .docker_tools import HAVE_DOCKER


sorter_full_list = [
HDSortSorter,
KlustaSorter,
Expand All @@ -33,10 +38,9 @@
sorter_dict = {s.sorter_name: s for s in sorter_full_list}


# generic launcher via function approach
def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
               grouping_property=None, use_docker=False, parallel=False, verbose=False, raise_error=True,
               n_jobs=-1, joblib_backend='loky', **params):
    """
    Generic function to run a sorter via function approach.

    Two usages with name or class:

    by name:
    >>> sorting = run_sorter('tridesclous', recording)

    by class:
    >>> sorting = run_sorter(TridesclousSorter, recording)

    Parameters
    ----------
    sorter_name_or_class: str or SorterClass
        The sorter to retrieve default parameters from
    recording: RecordingExtractor
        The recording extractor to be spike sorted
    output_folder: str or Path
        Path to output folder
    delete_output_folder: bool
        If True, output folder is deleted (default False)
    use_docker: bool
        If True and docker backend is installed, spike sorting is run in a docker image
    grouping_property: str
        Splits spike sorting by 'grouping_property' (e.g. 'groups')
    parallel: bool
        If True and spike sorting is by 'grouping_property', spike sorting jobs are launched in parallel
    verbose: bool
        If True, output is verbose
    raise_error: bool
        If True, an error is raised if spike sorting fails (default). If False, the process continues and the error is
        logged in the log file.
    n_jobs: int
        Number of jobs when parallel=True (default=-1)
    joblib_backend: str
        joblib backend when parallel=True (default='loky')
    **params: keyword args
        Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)'

    Returns
    -------
    sortingextractor: SortingExtractor
        The spike sorted data
    """
    if use_docker:
        # Typo fix in the message ("hitheron" -> "hither on").
        # NOTE(review): the code imports `hither2` — confirm the pip package
        # name suggested here matches the one actually required.
        assert HAVE_DOCKER, "To run in docker, install docker and hither on your system and >>> pip install hither docker"

        # The docker backend needs the sorter NAME (it keys the image lookup),
        # so resolve a class argument back to its registered name.
        if isinstance(sorter_name_or_class, str):
            sorter_name = sorter_name_or_class
        elif sorter_name_or_class in sorter_full_list:
            sorter_name = sorter_name_or_class.sorter_name
        else:
            raise ValueError('Unknown sorter')
        sorting = _run_sorter_hither(sorter_name, recording, output_folder=output_folder,
                                     delete_output_folder=delete_output_folder, grouping_property=grouping_property,
                                     parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
                                     joblib_backend=joblib_backend, **params)
    else:
        sorting = _run_sorter_local(sorter_name_or_class, recording, output_folder=output_folder,
                                    delete_output_folder=delete_output_folder, grouping_property=grouping_property,
                                    parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
                                    joblib_backend=joblib_backend, **params)
    return sorting


def available_sorters():
def installed_sorters():
    # Return the sorted list of sorter names whose backing package reports
    # itself as installed (via each SorterClass.is_installed()).
    l = sorted([s.sorter_name for s in sorter_full_list if s.is_installed()])
    return l


def print_sorter_versions():
"""
Prints versions of all installed sorters.
Expand Down