Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement proof of concept merge_clusters/split_clusters from tridesclous and columbia codes. #1996

Merged
merged 11 commits into from
Oct 5, 2023
2 changes: 1 addition & 1 deletion src/spikeinterface/core/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def is_set_global_dataset_folder():

########################################
global global_job_kwargs
global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global global_job_kwargs_set
global_job_kwargs_set = False

Expand Down
60 changes: 56 additions & 4 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,6 @@ def run(self):
self.gather_func(res)
else:
n_jobs = min(self.n_jobs, len(all_chunks))
######## Do you want to limit the number of threads per process?
######## It has to be done to speed up numpy a lot if multicores
######## Otherwise, np.dot will be slow. How to do that, up to you
######## This is just a suggestion, but here it adds a dependency

# parallel
with ProcessPoolExecutor(
Expand Down Expand Up @@ -436,3 +432,59 @@ def function_wrapper(args):
else:
with threadpool_limits(limits=max_threads_per_process):
return _func(segment_index, start_frame, end_frame, _worker_ctx)


# Here some utils copy/paste from DART (Charlie Windolf)


class MockFuture:
    """A non-concurrent class for mocking the concurrent.futures API.

    Captures a callable and its positional arguments at construction time;
    nothing runs until ``result()`` is called on the caller's own thread.
    """

    def __init__(self, f, *args):
        # Freeze the pending call; mirrors how a real Future holds work.
        self.f = f
        self.args = args

    def result(self):
        # Execute synchronously and return the value, like Future.result()
        # would after the work completed.
        func, func_args = self.f, self.args
        return func(*func_args)


class MockPoolExecutor:
    """A non-concurrent class for mocking the concurrent.futures API.

    Runs everything serially in the current process, which keeps code paths
    debuggable when ``n_jobs == 1``.

    Parameters
    ----------
    max_workers : ignored, accepted for API compatibility.
    mp_context : ignored, accepted for API compatibility.
    initializer : callable or None
        If given, called once at construction with ``*initargs``.
    initargs : tuple
        Arguments for ``initializer``. Defaults to an empty tuple
        (matching ``concurrent.futures``) so passing an initializer
        without initargs no longer raises ``TypeError``.
    context : ignored, accepted for API compatibility.
    """

    def __init__(
        self,
        max_workers=None,
        mp_context=None,
        initializer=None,
        initargs=(),
        context=None,
    ):
        if initializer is not None:
            initializer(*initargs)
        # Serial stand-ins for the executor's mapping API.
        self.map = map
        self.imap = map

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def submit(self, f, *args):
        # Defer execution until .result() is called, like a real executor.
        return MockFuture(f, *args)


class MockQueue:
    """Another helper class for turning off concurrency when debugging.

    A minimal FIFO backed by a plain list: ``put`` appends to the tail,
    ``get`` pops from the head.
    """

    def __init__(self):
        self.q = []

    def put(self, item):
        # Tail insertion keeps FIFO ordering.
        self.q.append(item)

    def get(self):
        # Head removal; raises IndexError when empty, like list.pop(0).
        return self.q.pop(0)


def get_poolexecutor(n_jobs):
    """Return the executor class matching the requested parallelism.

    A single job gets the serial ``MockPoolExecutor`` (easier debugging);
    anything else gets the real ``ProcessPoolExecutor``.
    """
    return MockPoolExecutor if n_jobs == 1 else ProcessPoolExecutor
7 changes: 4 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def run_node_pipeline(
job_name="pipeline",
mp_context=None,
gather_mode="memory",
gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
Expand All @@ -448,7 +449,7 @@ def run_node_pipeline(
if gather_mode == "memory":
gather_func = GatherToMemory()
elif gather_mode == "npy":
gather_func = GatherToNpy(folder, names)
gather_func = GatherToNpy(folder, names, **gather_kwargs)
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

Expand Down Expand Up @@ -593,9 +594,9 @@ class GatherToNpy:
* create the npy v1.0 header at the end with the correct shape and dtype
"""

def __init__(self, folder, names, npy_header_size=1024):
def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
self.folder = Path(folder)
self.folder.mkdir(parents=True, exist_ok=False)
self.folder.mkdir(parents=True, exist_ok=exist_ok)
assert names is not None
self.names = names
self.npy_header_size = npy_header_size
Expand Down
10 changes: 7 additions & 3 deletions src/spikeinterface/core/tests/test_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@ def test_global_tmp_folder():


def test_global_job_kwargs():
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
assert global_job_kwargs == dict(
n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
)
set_global_job_kwargs(**job_kwargs)
assert get_global_job_kwargs() == job_kwargs
# test updating only one field
partial_job_kwargs = dict(n_jobs=2)
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
assert global_job_kwargs == dict(
n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
)
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=10)
job_kwargs_split = fix_job_kwargs(new_job_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ def test_recordingless():

# delete original recording and rely on rec_attributes
if platform.system() != "Windows":
# this avoid reference on the folder
del we, recording
shutil.rmtree(cache_folder / "recording1")
we_loaded = WaveformExtractor.load(wf_folder, with_recording=False)
assert not we_loaded.has_recording()
Expand Down Expand Up @@ -554,6 +556,6 @@ def test_non_json_object():
# test_WaveformExtractor()
# test_extract_waveforms()
# test_portability()
# test_recordingless()
test_recordingless()
# test_compute_sparsity()
test_non_json_object()
# test_non_json_object()
45 changes: 45 additions & 0 deletions src/spikeinterface/sortingcomponents/clustering/clean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from .tools import FeaturesLoader, compute_template_from_sparse

# This is work in progress ...


def clean_clusters(
    peaks,
    peak_labels,
    recording,
    features_dict_or_folder,
    peak_sign="neg",
):
    """Clean cluster labels after an initial clustering pass.

    NOTE: work in progress — per-cluster spike counts and template peak
    amplitudes are computed but no cleaning rule is applied yet, so the
    labels are currently returned unchanged.

    Parameters
    ----------
    peaks : structured array of detected peaks.
    peak_labels : array of cluster labels per peak (-1 means unassigned).
    recording : recording object providing ``get_num_channels()``.
    features_dict_or_folder : dict or path
        Pre-loaded features, or a folder to lazily load them from.
    peak_sign : "neg" | "pos" | "both", default "neg"
        Polarity used when measuring template peak amplitude.

    Returns
    -------
    clean_labels : array
        Copy of ``peak_labels`` (unchanged for now).
    """
    total_channels = recording.get_num_channels()

    # Features can be given either pre-loaded (dict) or lazily from a folder.
    if isinstance(features_dict_or_folder, dict):
        features = features_dict_or_folder
    else:
        features = FeaturesLoader(features_dict_or_folder)

    clean_labels = peak_labels.copy()

    sparse_wfs = features["sparse_wfs"]
    sparse_mask = features["sparse_mask"]

    # All real cluster ids (-1 is the "noise/unassigned" label).
    labels_set = np.setdiff1d(peak_labels, [-1]).tolist()
    n = len(labels_set)

    # Spike count per cluster (not used yet — kept for the future cleaning rule).
    count = np.zeros(n, dtype="int64")
    for i, label in enumerate(labels_set):
        count[i] = np.sum(peak_labels == label)

    templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels)

    # Peak amplitude of each template, signed according to peak_sign
    # (not used yet — kept for the future cleaning rule).
    if peak_sign == "both":
        max_values = np.max(np.abs(templates), axis=(1, 2))
    elif peak_sign == "neg":
        max_values = -np.min(templates, axis=(1, 2))
    elif peak_sign == "pos":
        max_values = np.max(templates, axis=(1, 2))
    else:
        raise ValueError(f"peak_sign must be 'neg', 'pos' or 'both', got {peak_sign!r}")

    return clean_labels
Loading