Merge pull request #1996 from samuelgarcia/split_merge_clean
Implement a proof of concept of merge_clusters / split_clusters from the tridesclous and Columbia codebases.
samuelgarcia authored Oct 5, 2023
2 parents 20f20e4 + de2d642 commit cdc1ccb
Showing 11 changed files with 1,200 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/globals.py
@@ -96,7 +96,7 @@ def is_set_global_dataset_folder():

########################################
global global_job_kwargs
-global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
+global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global global_job_kwargs_set
global_job_kwargs_set = False

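For reference, the two new keys behave like any other global job kwarg. A minimal sketch (the values are arbitrary; the getter/setter are the ones exercised in test_globals.py below):

from spikeinterface.core import get_global_job_kwargs, set_global_job_kwargs

# mp_context selects the multiprocessing start method ("fork", "spawn", or None for
# the platform default); max_threads_per_process caps BLAS/OpenMP threads per worker.
set_global_job_kwargs(n_jobs=4, mp_context="spawn", max_threads_per_process=1)
print(get_global_job_kwargs())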
60 changes: 56 additions & 4 deletions src/spikeinterface/core/job_tools.py
@@ -380,10 +380,6 @@ def run(self):
self.gather_func(res)
else:
n_jobs = min(self.n_jobs, len(all_chunks))
-######## Do you want to limit the number of threads per process?
-######## It has to be done to speed up numpy a lot if multicores
-######## Otherwise, np.dot will be slow. How to do that, up to you
-######## This is just a suggestion, but here it adds a dependency

# parallel
with ProcessPoolExecutor(
@@ -436,3 +432,59 @@ def function_wrapper(args):
    else:
        with threadpool_limits(limits=max_threads_per_process):
            return _func(segment_index, start_frame, end_frame, _worker_ctx)
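The commented-out suggestion removed in the hunk above is what threadpool_limits (from the threadpoolctl package) implements in the worker. A standalone sketch of the pattern, with arbitrary matrix sizes:

import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.randn(2000, 2000)
b = np.random.randn(2000, 2000)

# Without a limit, BLAS can spawn one thread per core in every worker process,
# oversubscribing the machine when n_jobs > 1.
with threadpool_limits(limits=1):
    c = a @ b  # np.dot / @ runs single-threaded inside this block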


# Here are some utilities copy/pasted from DART (Charlie Windolf)


class MockFuture:
    """A non-concurrent class for mocking the concurrent.futures API."""

    def __init__(self, f, *args):
        self.f = f
        self.args = args

    def result(self):
        return self.f(*self.args)


class MockPoolExecutor:
    """A non-concurrent class for mocking the concurrent.futures API."""

    def __init__(
        self,
        max_workers=None,
        mp_context=None,
        initializer=None,
        initargs=None,
        context=None,
    ):
        if initializer is not None:
            initializer(*initargs)
        self.map = map
        self.imap = map

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def submit(self, f, *args):
        return MockFuture(f, *args)


class MockQueue:
    """Another helper class for turning off concurrency when debugging."""

    def __init__(self):
        self.q = []
        self.put = self.q.append
        self.get = lambda: self.q.pop(0)


def get_poolexecutor(n_jobs):
    if n_jobs == 1:
        return MockPoolExecutor
    else:
        return ProcessPoolExecutor
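These mocks let calling code be written once against the concurrent.futures API and still run serially (and debuggably, in a single process) when n_jobs == 1. A hypothetical usage sketch:

from concurrent.futures import ProcessPoolExecutor

def square(x):
    return x * x

n_jobs = 1  # with n_jobs > 1 the same code would get a real ProcessPoolExecutor
Executor = get_poolexecutor(n_jobs)
with Executor(max_workers=n_jobs) as executor:
    future = executor.submit(square, 7)
    print(future.result())  # 49, computed immediately in the calling process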
7 changes: 4 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
@@ -432,6 +432,7 @@ def run_node_pipeline(
job_name="pipeline",
mp_context=None,
gather_mode="memory",
+gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
@@ -448,7 +449,7 @@
if gather_mode == "memory":
gather_func = GatherToMemory()
elif gather_mode == "npy":
-gather_func = GatherToNpy(folder, names)
+gather_func = GatherToNpy(folder, names, **gather_kwargs)
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

@@ -593,9 +594,9 @@ class GatherToNpy:
* create the npy v1.0 header at the end with the correct shape and dtype
"""

-def __init__(self, folder, names, npy_header_size=1024):
+def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
self.folder = Path(folder)
-self.folder.mkdir(parents=True, exist_ok=False)
+self.folder.mkdir(parents=True, exist_ok=exist_ok)
assert names is not None
self.names = names
self.npy_header_size = npy_header_size
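Forwarding gather_kwargs means a pipeline writing npy output can now opt into reusing an existing folder. A hedged sketch of the call (recording, nodes and job_kwargs are placeholders; only the gather-related arguments come from this diff):

outs = run_node_pipeline(
    recording,
    nodes,
    job_kwargs,
    gather_mode="npy",
    folder="features_folder",
    names=["sparse_wfs"],
    gather_kwargs=dict(exist_ok=True),  # forwarded to GatherToNpy -> folder.mkdir
)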
10 changes: 7 additions & 3 deletions src/spikeinterface/core/tests/test_globals.py
@@ -37,16 +37,20 @@ def test_global_tmp_folder():


def test_global_job_kwargs():
-job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
+job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global_job_kwargs = get_global_job_kwargs()
-assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
+assert global_job_kwargs == dict(
+    n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
+)
set_global_job_kwargs(**job_kwargs)
assert get_global_job_kwargs() == job_kwargs
# test updating only one field
partial_job_kwargs = dict(n_jobs=2)
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
-assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
+assert global_job_kwargs == dict(
+    n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
+)
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=10)
job_kwargs_split = fix_job_kwargs(new_job_kwargs)
6 changes: 4 additions & 2 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -346,6 +346,8 @@ def test_recordingless():

# delete original recording and rely on rec_attributes
if platform.system() != "Windows":
+# this avoids keeping a reference to the folder
+del we, recording
shutil.rmtree(cache_folder / "recording1")
we_loaded = WaveformExtractor.load(wf_folder, with_recording=False)
assert not we_loaded.has_recording()
@@ -554,6 +556,6 @@ def test_non_json_object():
# test_WaveformExtractor()
# test_extract_waveforms()
# test_portability()
-# test_recordingless()
+test_recordingless()
# test_compute_sparsity()
-test_non_json_object()
+# test_non_json_object()
45 changes: 45 additions & 0 deletions src/spikeinterface/sortingcomponents/clustering/clean.py
@@ -0,0 +1,45 @@
import numpy as np

from .tools import FeaturesLoader, compute_template_from_sparse

# This is work in progress ...


def clean_clusters(
    peaks,
    peak_labels,
    recording,
    features_dict_or_folder,
    peak_sign="neg",
):
    total_channels = recording.get_num_channels()

    if isinstance(features_dict_or_folder, dict):
        features = features_dict_or_folder
    else:
        features = FeaturesLoader(features_dict_or_folder)

    clean_labels = peak_labels.copy()

    sparse_wfs = features["sparse_wfs"]
    sparse_mask = features["sparse_mask"]

    labels_set = np.setdiff1d(peak_labels, [-1]).tolist()
    n = len(labels_set)

    count = np.zeros(n, dtype="int64")
    for i, label in enumerate(labels_set):
        count[i] = np.sum(peak_labels == label)
    print(count)

    templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels)

    if peak_sign == "both":
        max_values = np.max(np.abs(templates), axis=(1, 2))
    elif peak_sign == "neg":
        max_values = -np.min(templates, axis=(1, 2))
    elif peak_sign == "pos":
        max_values = np.max(templates, axis=(1, 2))
    print(max_values)

    return clean_labels
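The peak_sign branches reduce each (num_samples, num_channels) template to a single positive amplitude. A tiny self-contained illustration with random stand-in templates:

import numpy as np

rng = np.random.default_rng(0)
templates = rng.standard_normal((4, 90, 16))  # (num_templates, num_samples, num_channels)

# "neg": amplitude is the sign-flipped trough depth over all samples and channels
amp_neg = -np.min(templates, axis=(1, 2))
# "both": largest absolute excursion instead
amp_both = np.max(np.abs(templates), axis=(1, 2))
assert amp_neg.shape == amp_both.shape == (4,)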