
Commit: merge with main

samuelgarcia committed Oct 6, 2023
2 parents 5707879 + cdc1ccb commit c4a7609
Showing 15 changed files with 1,221 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/globals.py
@@ -96,7 +96,7 @@ def is_set_global_dataset_folder():

########################################
global global_job_kwargs
- global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
+ global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global global_job_kwargs_set
global_job_kwargs_set = False
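
With this change the global job kwargs also carry mp_context and max_threads_per_process. A minimal sketch of setting them, assuming the getter/setter exported from spikeinterface.core (the same ones exercised in test_globals.py below); "spawn" is only one possible multiprocessing context:

from spikeinterface.core import set_global_job_kwargs, get_global_job_kwargs

# run 4 worker processes, but cap BLAS/OpenMP threads inside each worker
set_global_job_kwargs(n_jobs=4, mp_context="spawn", max_threads_per_process=1)
print(get_global_job_kwargs())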

60 changes: 56 additions & 4 deletions src/spikeinterface/core/job_tools.py
@@ -380,10 +380,6 @@ def run(self):
self.gather_func(res)
else:
n_jobs = min(self.n_jobs, len(all_chunks))
- ######## Do you want to limit the number of threads per process?
- ######## It has to be done to speed up numpy a lot if multicores
- ######## Otherwise, np.dot will be slow. How to do that, up to you
- ######## This is just a suggestion, but here it adds a dependency

# parallel
with ProcessPoolExecutor(
@@ -436,3 +432,59 @@ def function_wrapper(args):
else:
with threadpool_limits(limits=max_threads_per_process):
return _func(segment_index, start_frame, end_frame, _worker_ctx)


# Here some utils copy/paste from DART (Charlie Windolf)


class MockFuture:
"""A non-concurrent class for mocking the concurrent.futures API."""

def __init__(self, f, *args):
self.f = f
self.args = args

def result(self):
return self.f(*self.args)


class MockPoolExecutor:
"""A non-concurrent class for mocking the concurrent.futures API."""

def __init__(
self,
max_workers=None,
mp_context=None,
initializer=None,
initargs=None,
context=None,
):
if initializer is not None:
initializer(*initargs)
self.map = map
self.imap = map

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return

def submit(self, f, *args):
return MockFuture(f, *args)


class MockQueue:
"""Another helper class for turning off concurrency when debugging."""

def __init__(self):
self.q = []
self.put = self.q.append
self.get = lambda: self.q.pop(0)


def get_poolexecutor(n_jobs):
if n_jobs == 1:
return MockPoolExecutor
else:
return ProcessPoolExecutor
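
get_poolexecutor lets the same pool-based code run either in real worker processes or entirely in the calling process, which is handy for debugging. A small usage sketch, assuming the helpers are importable from spikeinterface.core.job_tools; the _square function and the toy inputs are illustrative:

from spikeinterface.core.job_tools import get_poolexecutor

def _square(x):
    # defined at module level so it can also be pickled when ProcessPoolExecutor is used
    return x * x

n_jobs = 1  # n_jobs == 1 selects MockPoolExecutor, so submit() runs immediately in-process
Executor = get_poolexecutor(n_jobs)
with Executor(max_workers=n_jobs) as executor:
    results = [executor.submit(_square, i).result() for i in range(4)]
print(results)  # [0, 1, 4, 9]

With n_jobs > 1 the same code goes through ProcessPoolExecutor and should then sit under an if __name__ == "__main__": guard on platforms that spawn workers.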
7 changes: 4 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
@@ -432,6 +432,7 @@ def run_node_pipeline(
job_name="pipeline",
mp_context=None,
gather_mode="memory",
+ gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
@@ -448,7 +449,7 @@
if gather_mode == "memory":
gather_func = GatherToMemory()
elif gather_mode == "npy":
- gather_func = GatherToNpy(folder, names)
+ gather_func = GatherToNpy(folder, names, **gather_kwargs)
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

@@ -593,9 +594,9 @@ class GatherToNpy:
* create the npy v1.0 header at the end with the correct shape and dtype
"""

- def __init__(self, folder, names, npy_header_size=1024):
+ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
self.folder = Path(folder)
- self.folder.mkdir(parents=True, exist_ok=False)
+ self.folder.mkdir(parents=True, exist_ok=exist_ok)
assert names is not None
self.names = names
self.npy_header_size = npy_header_size
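
run_node_pipeline now forwards gather_kwargs, so callers using gather_mode="npy" can pass exist_ok=True and write into a folder that already exists. The folder semantics reduce to Path.mkdir; a runnable sketch of the old and new behaviour:

from pathlib import Path
import tempfile

folder = Path(tempfile.mkdtemp()) / "gathered"
folder.mkdir(parents=True, exist_ok=False)       # first run: folder is created
try:
    folder.mkdir(parents=True, exist_ok=False)   # old default: re-running onto the same folder raises
except FileExistsError:
    print("folder already exists")
folder.mkdir(parents=True, exist_ok=True)        # with gather_kwargs={"exist_ok": True}: folder is reused silently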
10 changes: 7 additions & 3 deletions src/spikeinterface/core/tests/test_globals.py
@@ -37,16 +37,20 @@ def test_global_tmp_folder():


def test_global_job_kwargs():
- job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
+ job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global_job_kwargs = get_global_job_kwargs()
- assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True)
+ assert global_job_kwargs == dict(
+ n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
+ )
set_global_job_kwargs(**job_kwargs)
assert get_global_job_kwargs() == job_kwargs
# test updating only one field
partial_job_kwargs = dict(n_jobs=2)
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
- assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
+ assert global_job_kwargs == dict(
+ n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
+ )
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=10)
job_kwargs_split = fix_job_kwargs(new_job_kwargs)
4 changes: 3 additions & 1 deletion src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -346,6 +346,8 @@ def test_recordingless():

# delete original recording and rely on rec_attributes
if platform.system() != "Windows":
+ # this avoids keeping a reference to the folder
+ del we, recording
shutil.rmtree(cache_folder / "recording1")
we_loaded = WaveformExtractor.load(wf_folder, with_recording=False)
assert not we_loaded.has_recording()
@@ -554,7 +556,7 @@ def test_non_json_object():
# test_WaveformExtractor()
# test_extract_waveforms()
# test_portability()
- # test_recordingless()
+ test_recordingless()
# test_compute_sparsity()
# test_non_json_object()
test_empty_sorting()
6 changes: 2 additions & 4 deletions src/spikeinterface/preprocessing/remove_artifacts.py
@@ -107,8 +107,6 @@ def __init__(
time_jitter=0,
waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"},
):
- import scipy.interpolate
-
available_modes = ("zeros", "linear", "cubic", "average", "median")
num_seg = recording.get_num_segments()

@@ -236,8 +234,6 @@ def __init__(
time_pad,
sparsity,
):
- import scipy.interpolate
-
BasePreprocessorSegment.__init__(self, parent_recording_segment)

self.triggers = np.asarray(triggers, dtype="int64")
@@ -285,6 +281,8 @@ def get_traces(self, start_frame, end_frame, channel_indices):
elif trig + pad[1] >= end_frame - start_frame:
traces[trig - pad[0] :, :] = 0
elif self.mode in ["linear", "cubic"]:
+ import scipy.interpolate
+
for trig in triggers:
if pad is None:
pre_data_end_idx = trig - 1
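
scipy.interpolate is now imported only on the code path that needs it, so the "zeros" mode keeps working without scipy installed. A minimal sketch of the deferred-import pattern; the function name and arguments here are illustrative, not the actual API:

import numpy as np

def _fill_artifact(traces, mode="zeros"):
    if mode in ("linear", "cubic"):
        import scipy.interpolate  # deferred: only pulled in when interpolation is requested
        # ... build scipy.interpolate.interp1d over the samples around the trigger ...
    else:
        traces[:] = 0  # default mode needs numpy only
    return traces

print(_fill_artifact(np.ones((4, 2))))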
11 changes: 6 additions & 5 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,7 +6,7 @@

from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms
from spikeinterface.core.job_tools import fix_job_kwargs
- from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
+ from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter

try:
import hdbscan
@@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
"filtering": {"dtype": "float32"},
"filtering": {"freq_min": 150, "dtype": "float32"},
"detection": {"peak_sign": "neg", "detect_threshold": 5},
"selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
"localization": {},
@@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
- recording_f = bandpass_filter(recording, **filtering_params)
+ recording_f = highpass_filter(recording, **filtering_params)
recording_f = common_reference(recording_f)
else:
recording_f = recording

+ # recording_f = whiten(recording_f, dtype="float32")
recording_f = zscore(recording_f, dtype="float32")

## Then, we are detecting peaks with a locally_exclusive method
@@ -98,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We launch a clustering (using hdbscan) relying on positions and features extracted on
## the fly from the snippets
clustering_params = params["clustering"].copy()
clustering_params["waveforms_kwargs"] = params["waveforms"]
clustering_params["waveforms"] = params["waveforms"].copy()

for k in ["ms_before", "ms_after"]:
clustering_params["waveforms_kwargs"][k] = params["general"][k]
clustering_params["waveforms"][k] = params["general"][k]

clustering_params.update(dict(shared_memory=params["shared_memory"]))
clustering_params["job_kwargs"] = job_kwargs
45 changes: 45 additions & 0 deletions src/spikeinterface/sortingcomponents/clustering/clean.py
@@ -0,0 +1,45 @@
import numpy as np

from .tools import FeaturesLoader, compute_template_from_sparse

# This is work in progress ...


def clean_clusters(
peaks,
peak_labels,
recording,
features_dict_or_folder,
peak_sign="neg",
):
total_channels = recording.get_num_channels()

if isinstance(features_dict_or_folder, dict):
features = features_dict_or_folder
else:
features = FeaturesLoader(features_dict_or_folder)

clean_labels = peak_labels.copy()

sparse_wfs = features["sparse_wfs"]
sparse_mask = features["sparse_mask"]

labels_set = np.setdiff1d(peak_labels, [-1]).tolist()
n = len(labels_set)

count = np.zeros(n, dtype="int64")
for i, label in enumerate(labels_set):
count[i] = np.sum(peak_labels == label)
print(count)

templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels)

if peak_sign == "both":
max_values = np.max(np.abs(templates), axis=(1, 2))
elif peak_sign == "neg":
max_values = -np.min(templates, axis=(1, 2))
elif peak_sign == "pos":
max_values = np.max(templates, axis=(1, 2))
print(max_values)

return clean_labels
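
The per-cluster spike counting at the top of clean_clusters reduces to the following toy, runnable illustration; the labels array is made up, and -1 marks unassigned peaks as in the pipeline:

import numpy as np

peak_labels = np.array([-1, 0, 0, 1, 1, 1, 2])
labels_set = np.setdiff1d(peak_labels, [-1]).tolist()   # drop the "unassigned" label
count = np.zeros(len(labels_set), dtype="int64")
for i, label in enumerate(labels_set):
    count[i] = np.sum(peak_labels == label)
print(labels_set, count)  # [0, 1, 2] [2 3 1]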
(the following hunks belong to another changed file; its name is not shown)
@@ -593,29 +593,34 @@ def remove_duplicates_via_matching(

chunk_size = duration + 3 * margin

- method_kwargs.update(
+ local_params = method_kwargs.copy()
+
+ local_params.update(
{
"waveform_extractor": waveform_extractor,
"noise_levels": noise_levels,
"amplitudes": [0.95, 1.05],
"omp_min_sps": 0.1,
"omp_min_sps": 0.05,
}
)

+ spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True)
+ indices = np.argsort(counts)

ignore_ids = []
similar_templates = [[], []]

- for i in range(nb_templates):
+ for i in indices:
t_start = padding + i * duration
t_stop = padding + (i + 1) * duration

sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging)
method_kwargs.update({"ignored_ids": ignore_ids + [i]})
local_params.update({"ignored_ids": ignore_ids + [i]})
spikes, computed = find_spikes_from_templates(
- sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
+ sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs
)
if method == "circus-omp-svd":
- method_kwargs.update(
+ local_params.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
@@ -629,7 +634,7 @@
}
)
elif method == "circus-omp":
- method_kwargs.update(
+ local_params.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
Expand Down Expand Up @@ -661,7 +666,7 @@ def remove_duplicates_via_matching(
labels = np.unique(new_labels)
labels = labels[labels >= 0]

- del recording, sub_recording, method_kwargs
+ del recording, sub_recording, local_params, waveform_extractor
os.remove(tmp_filename)

return labels, new_labels
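
The switch from mutating method_kwargs to a local_params copy keeps the caller's kwargs dict clean across iterations. A small, self-contained illustration of that copy-then-update pattern; the keys below are taken from this function, while run_once is a made-up helper:

method_kwargs = {"omp_min_sps": 0.05, "amplitudes": [0.95, 1.05]}

def run_once(shared_kwargs, unit_index):
    local_params = shared_kwargs.copy()                 # shallow copy is enough when only adding keys
    local_params.update({"ignored_ids": [unit_index]})
    return local_params

first = run_once(method_kwargs, 0)
second = run_once(method_kwargs, 1)
print("ignored_ids" in method_kwargs)                   # False: the caller's dict is untouched
print(first["ignored_ids"], second["ignored_ids"])      # [0] [1]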