This repository has been archived by the owner on Jun 6, 2023. It is now read-only.
Run sorters in docker through hither #223
Draft

alejoe91 wants to merge 6 commits into master from docker
Changes from all commits (6 commits):
ca5719e  WIP  (alejoe91)
d71c569  more wip  (alejoe91)
c31b3d6  Fix imports  (alejoe91)
7802ff3  Remove local module from hither decorator  (alejoe91)
5c7e067  Move run functions in separate sub-folder  (alejoe91)
5382d81  Correctly import run functions  (alejoe91)
@@ -0,0 +1,34 @@
import spikeextractors as se
import spikesorters as ss


rec, _ = se.example_datasets.toy_example(dumpable=True)

output_folder = "ms4_test_docker"

sorting = ss.run_klusta(rec, output_folder=output_folder, use_docker=True)

print(f"KL found #{len(sorting.get_unit_ids())} units")


# output_folder = "kl_test_docker"
#
# sorting_KL = ssd.run_klusta(rec, output_folder=output_folder)
#
# print(f"KL found #{len(sorting_KL.get_unit_ids())} units")
#
# rec, _ = se.example_datasets.toy_example(dumpable=True)
#
# output_folder = "sc_test_docker"
#
# sorting_SC = ssd.run_spykingcircus(rec, output_folder=output_folder)
#
# print(f"SC found #{len(sorting_SC.get_unit_ids())} units")
#
# rec, _ = se.example_datasets.toy_example(dumpable=True)
#
# output_folder = "hs_test_docker"
#
# sorting_HS = ssd.run_herdingspikes(rec, output_folder=output_folder)
#
# print(f"HS found #{len(sorting_HS.get_unit_ids())} units")
@@ -0,0 +1,74 @@
import spikeextractors as se
import time
import numpy as np
from pathlib import Path

ss_folder = Path(__file__).parent

try:
    import hither2 as hither
    import docker

    HAVE_DOCKER = True

    default_docker_images = {
        "klusta": hither.DockerImageFromScript(name="klusta", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "klusta" / "Dockerfile")),
        "mountainsort4": hither.DockerImageFromScript(name="ms4", dockerfile=str(ss_folder / "docker_images" / "v0.12" / "mountainsort4" / "Dockerfile")),
        "herdingspikes": hither.LocalDockerImage('spikeinterface/herdingspikes-si-0.12:0.3.7'),
        "spykingcircus": hither.LocalDockerImage('spikeinterface/spyking-circus-si-0.12:1.0.7')
    }

except ImportError:
    HAVE_DOCKER = False


def modify_input_folder(dump_dict, input_folder="/input"):
    if "kwargs" in dump_dict.keys():
        # recurse into the extractor kwargs, propagating the target input_folder
        dcopy_kwargs, folder_to_mount = modify_input_folder(dump_dict["kwargs"], input_folder)
        dump_dict["kwargs"] = dcopy_kwargs
        return dump_dict, folder_to_mount
    else:
        if "file_path" in dump_dict:
            file_path = Path(dump_dict["file_path"])
            folder_to_mount = file_path.parent
            file_relative = file_path.relative_to(folder_to_mount)
            dump_dict["file_path"] = f"{input_folder}/{str(file_relative)}"
            return dump_dict, folder_to_mount
        elif "folder_path" in dump_dict:
            folder_path = Path(dump_dict["folder_path"])
            folder_to_mount = folder_path.parent
            folder_relative = folder_path.relative_to(folder_to_mount)
            dump_dict["folder_path"] = f"{input_folder}/{str(folder_relative)}"
            return dump_dict, folder_to_mount
        elif "file_or_folder_path" in dump_dict:
            file_or_folder_path = Path(dump_dict["file_or_folder_path"])
            folder_to_mount = file_or_folder_path.parent
            file_or_folder_relative = file_or_folder_path.relative_to(folder_to_mount)
            dump_dict["file_or_folder_path"] = f"{input_folder}/{str(file_or_folder_relative)}"
            return dump_dict, folder_to_mount
        else:
            raise Exception("Could not find file_path, folder_path, or file_or_folder_path in the dumped recording dict")


def return_local_data_folder(recording, input_folder='/input'):
    """
    Modifies the recording dictionary so that the file_path, folder_path, or file_or_folder_path is relative to the
    'input_folder'.

    Parameters
    ----------
    recording: se.RecordingExtractor
    input_folder: str

    Returns
    -------
    dump_dict: dict
        Dumped recording dict with paths remapped to 'input_folder'
    folder_to_mount: Path
        Host folder that has to be mounted as 'input_folder' in the container
    """
    assert recording.is_dumpable
    from copy import deepcopy

    d = recording.dump_to_dict()
    dcopy = deepcopy(d)

    return modify_input_folder(dcopy, input_folder)
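To make the path remapping concrete, here is a small worked sketch of modify_input_folder. The import path and the example dict are assumptions: the relative imports later in this PR suggest the module lives at spikesorters/docker_tools.py, and the file path below is made up.

from spikesorters.docker_tools import modify_input_folder  # assumed module path

# hand-written stand-in for recording.dump_to_dict() pointing at a host file
dump_dict = {
    "class": "BinDatRecordingExtractor",  # hypothetical dumped extractor class
    "kwargs": {"file_path": "/home/user/data/rec.bin", "sampling_frequency": 30000.0},
}

new_dict, folder_to_mount = modify_input_folder(dump_dict, input_folder="/input")

print(new_dict["kwargs"]["file_path"])  # /input/rec.bin  (path as seen inside the container)
print(folder_to_mount)                  # /home/user/data (host folder, later bind-mounted to /input)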
@@ -0,0 +1 @@
from .run_functions import _run_sorter_local, _run_sorter_hither
@@ -0,0 +1,149 @@
from ..docker_tools import HAVE_DOCKER
from ..sorterlist import sorter_dict, sorter_full_list


if HAVE_DOCKER:
    # conditional definition of hither tools
    import time
    from pathlib import Path
    import hither2 as hither
    import spikeextractors as se
    import numpy as np
    import shutil
    from ..docker_tools import modify_input_folder, default_docker_images

    class SpikeSortingDockerHook(hither.RuntimeHook):
        def __init__(self):
            super().__init__()

        def precontainer(self, context: hither.PreContainerContext):
            # this gets run outside the container before the run, and we have a chance to mutate the kwargs,
            # add bind mounts, and set the image
            input_directory = context.kwargs['input_directory']
            output_directory = context.kwargs['output_directory']

            print("Input:", input_directory)
            print("Output:", output_directory)
            context.add_bind_mount(hither.BindMount(source=input_directory,
                                                    target='/input', read_only=True))
            context.add_bind_mount(hither.BindMount(source=output_directory,
                                                    target='/output', read_only=False))
            context.image = default_docker_images[context.kwargs['sorter_name']]
            context.kwargs['output_directory'] = '/output'
            context.kwargs['input_directory'] = '/input'


    @hither.function('run_sorter_docker_with_container',
                     '0.1.0',
                     image=True,
                     runtime_hooks=[SpikeSortingDockerHook()])
    def run_sorter_docker_with_container(
            recording_dict, sorter_name, input_directory, output_directory, **kwargs
    ):
        recording = se.load_extractor_from_dict(recording_dict)
        # run sorter
        kwargs["output_folder"] = f"{output_directory}/working"
        t_start = time.time()
        # set output folder within the container
        sorting = _run_sorter_local(sorter_name, recording, **kwargs)
        t_stop = time.time()
        print(f'{sorter_name} run time {np.round(t_stop - t_start)}s')
        # save sorting to npz
        se.NpzSortingExtractor.write_sorting(sorting, f"{output_directory}/sorting_docker.npz")

    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
                           n_jobs=-1, joblib_backend='loky', **params):
        assert recording.is_dumpable, "Only dumpable recordings can be run in docker"
        if output_folder is None:
            output_folder = sorter_name + '_output'
        output_folder = Path(output_folder).absolute()
        output_folder.mkdir(exist_ok=True, parents=True)

        with hither.Config(use_container=True, show_console=True):
            # remap the recording paths to /input and find the host folder to bind-mount
            dump_dict_container, input_directory = modify_input_folder(recording.dump_to_dict(), '/input')
            print(dump_dict_container)
            kwargs = dict(recording_dict=dump_dict_container,
                          sorter_name=sorter_name,
                          output_folder=str(output_folder),
                          delete_output_folder=False,
                          grouping_property=grouping_property, parallel=parallel,
                          verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
                          joblib_backend=joblib_backend)

            kwargs.update(params)
            kwargs.update({'input_directory': str(input_directory), 'output_directory': str(output_folder)})
            # the job runs run_sorter_docker_with_container inside the sorter's docker image
            sorting_job = hither.Job(run_sorter_docker_with_container, kwargs)
            sorting_job.wait()
            sorting = se.NpzSortingExtractor(output_folder / "sorting_docker.npz")
            if delete_output_folder:
                shutil.rmtree(output_folder)
            return sorting
else:
    def _run_sorter_hither(sorter_name, recording, output_folder=None, delete_output_folder=False,
                           grouping_property=None, parallel=False, verbose=False, raise_error=True,
                           n_jobs=-1, joblib_backend='loky', **params):
        raise ImportError("To run sorters in docker, install hither and docker: pip install hither docker")


# generic launcher via function approach
def _run_sorter_local(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
                      grouping_property=None, parallel=False, verbose=False, raise_error=True, n_jobs=-1,
                      joblib_backend='loky', **params):
    """
    Generic function to run a sorter via function approach.

    Two usages with name or class:

    by name:
    >>> sorting = run_sorter('tridesclous', recording)

    by class:
    >>> sorting = run_sorter(TridesclousSorter, recording)

    Parameters
    ----------
    sorter_name_or_class: str or SorterClass
        The sorter to retrieve default parameters from
    recording: RecordingExtractor
        The recording extractor to be spike sorted
    output_folder: str or Path
        Path to output folder
    delete_output_folder: bool
        If True, output folder is deleted (default False)
    grouping_property: str
        Splits spike sorting by 'grouping_property' (e.g. 'groups')
    parallel: bool
        If True and spike sorting is by 'grouping_property', spike sorting jobs are launched in parallel
    verbose: bool
        If True, output is verbose
    raise_error: bool
        If True, an error is raised if spike sorting fails (default). If False, the process continues and the
        error is logged in the log file.
    n_jobs: int
        Number of jobs when parallel=True (default=-1)
    joblib_backend: str
        joblib backend when parallel=True (default='loky')
    **params: keyword args
        Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)')

    Returns
    -------
    sortingextractor: SortingExtractor
        The spike sorted data
    """
    if isinstance(sorter_name_or_class, str):
        SorterClass = sorter_dict[sorter_name_or_class]
    elif sorter_name_or_class in sorter_full_list:
        SorterClass = sorter_name_or_class
    else:
        raise ValueError('Unknown sorter')

    sorter = SorterClass(recording=recording, output_folder=output_folder, grouping_property=grouping_property,
                         verbose=verbose, delete_output_folder=delete_output_folder)
    sorter.set_params(**params)
    sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
    sortingextractor = sorter.get_result(raise_error=raise_error)

    return sortingextractor
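From the caller's side, the docker path above boils down to the sketch below. It is not part of the diff and assumes hither2/docker are installed and that the klusta image can be built; the folder name is illustrative. The comments describe the on-disk artifacts that the code above produces.

import spikeextractors as se
import spikesorters as ss
from pathlib import Path

rec, _ = se.example_datasets.toy_example(dumpable=True)
output_folder = Path("kl_hither_test").absolute()

# run_sorter with use_docker=True dispatches to _run_sorter_hither above
sorting = ss.run_sorter("klusta", rec, output_folder=str(output_folder), use_docker=True)

# artifacts produced by the hither job (kept because delete_output_folder defaults to False):
#   <output_folder>/working/            sorter scratch folder (seen as /output/working inside the container)
#   <output_folder>/sorting_docker.npz  NpzSortingExtractor written in the container and re-loaded on the host
print((output_folder / "sorting_docker.npz").exists())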
@@ -13,6 +13,11 @@
 from .yass import YassSorter
 from .combinato import CombinatoSorter
 
+
+from .run_funtions import _run_sorter_local, _run_sorter_hither
+from .docker_tools import HAVE_DOCKER
+
+
 sorter_full_list = [
     HDSortSorter,
     KlustaSorter,
@@ -33,10 +38,9 @@
 sorter_dict = {s.sorter_name: s for s in sorter_full_list}
 
 
-# generic launcher via function approach
 def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_output_folder=False,
-               grouping_property=None, parallel=False, verbose=False, raise_error=True, n_jobs=-1, joblib_backend='loky',
-               **params):
+               grouping_property=None, use_docker=False, parallel=False, verbose=False, raise_error=True, n_jobs=-1,
+               joblib_backend='loky', **params):
     """
     Generic function to run a sorter via function approach.
 
@@ -58,6 +62,8 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu
         Path to output folder
     delete_output_folder: bool
         If True, output folder is deleted (default False)
+    use_docker: bool
+        If True and docker backend is installed, spike sorting is run in a docker image
     grouping_property: str
         Splits spike sorting by 'grouping_property' (e.g. 'groups')
     parallel: bool
@@ -80,20 +86,26 @@ def run_sorter(sorter_name_or_class, recording, output_folder=None, delete_outpu
         The spike sorted data
 
     """
-    if isinstance(sorter_name_or_class, str):
-        SorterClass = sorter_dict[sorter_name_or_class]
-    elif sorter_name_or_class in sorter_full_list:
-        SorterClass = sorter_name_or_class
+    if use_docker:
+        assert HAVE_DOCKER, "To run in docker, install docker and hither on your system: pip install hither docker"
+
+        # we need sorter name here
+        if isinstance(sorter_name_or_class, str):
+            sorter_name = sorter_name_or_class
+        elif sorter_name_or_class in sorter_full_list:
+            sorter_name = sorter_name_or_class.sorter_name
+        else:
+            raise ValueError('Unknown sorter')
+        sorting = _run_sorter_hither(sorter_name, recording, output_folder=output_folder,
+                                     delete_output_folder=delete_output_folder, grouping_property=grouping_property,
+                                     parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
+                                     joblib_backend=joblib_backend, **params)
     else:
-        raise (ValueError('Unknown sorter'))
-
-    sorter = SorterClass(recording=recording, output_folder=output_folder, grouping_property=grouping_property,
-                         verbose=verbose, delete_output_folder=delete_output_folder)
-    sorter.set_params(**params)
-    sorter.run(raise_error=raise_error, parallel=parallel, n_jobs=n_jobs, joblib_backend=joblib_backend)
-    sortingextractor = sorter.get_result(raise_error=raise_error)
-
-    return sortingextractor
+        sorting = _run_sorter_local(sorter_name_or_class, recording, output_folder=output_folder,
+                                    delete_output_folder=delete_output_folder, grouping_property=grouping_property,
+                                    parallel=parallel, verbose=verbose, raise_error=raise_error, n_jobs=n_jobs,
+                                    joblib_backend=joblib_backend, **params)
+    return sorting
 
 
 def available_sorters():
@@ -110,6 +122,7 @@ def installed_sorters():
     l = sorted([s.sorter_name for s in sorter_full_list if s.is_installed()])
     return l
 
+
 def print_sorter_versions():
     """
     Prints versions of all installed sorters.
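Sorter-specific arguments keep flowing through **params on both branches of the new dispatch. The sketch below uses get_default_params, the helper mentioned in the _run_sorter_local docstring above, to fetch valid parameter names; klusta and the folder name are just illustrative choices.

import spikeextractors as se
import spikesorters as ss

rec, _ = se.example_datasets.toy_example(dumpable=True)

# inspect the sorter-specific defaults, then pass them (or overrides) through **params;
# they reach the sorter the same way whether the run is local or inside the docker image
params = ss.get_default_params("klusta")
print(params)

sorting = ss.run_sorter("klusta", rec, output_folder="kl_docker_params",
                        use_docker=True, **params)
print(f"klusta found {len(sorting.get_unit_ids())} units")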
Review comment: I would name this hither_tools instead of docker_tools.