From 5fd2627fbcf523dd8ac9c16706120e0e82930942 Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Thu, 29 Jun 2023 12:44:35 +0100 Subject: [PATCH 01/74] Allow any integer type. --- src/spikeinterface/core/numpyextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 398ef18130..17c2849b6d 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -234,7 +234,7 @@ class NumpySortingSegment(BaseSortingSegment): def __init__(self, units_dict): BaseSortingSegment.__init__(self) for unit_id, times in units_dict.items(): - assert times.dtype.kind == 'i', 'numpy array of spike times must be integer' + assert (times.dtype.kind == 'i') or (times.dtype.kind == 'u'), 'numpy array of spike times must be integer' assert np.all(np.diff(times) >= 0), 'unsorted times' self._units_dict = units_dict From 077a7fe28932be5d8dbf81bb946529c4ca6e90f9 Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Thu, 29 Jun 2023 12:45:48 +0100 Subject: [PATCH 02/74] Fix problem with non-numeric unit IDs. --- src/spikeinterface/extractors/mdaextractors.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 5b97f5de07..2d4b98635e 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -197,10 +197,14 @@ def write_sorting(sorting, save_path, write_primary_channels=False): times_list = [] labels_list = [] primary_channels_list = [] - for unit_id in unit_ids: + for unit_id_i, unit_id in enumerate(unit_ids): times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) - labels_list.append(np.ones(times.shape) * unit_id) + # unit id may not be numeric + if unit_id.dtype.kind in 'biufc': + labels_list.append(np.ones(times.shape) * unit_id) + else: + labels_list.append(np.ones(times.shape) * unit_id_i) if write_primary_channels: if 'max_channel' in sorting.get_unit_property_names(unit_id): primary_channels_list.append([sorting.get_unit_property(unit_id, 'max_channel')] * times.shape[0]) From 9b86b485cd6861469e1bd6ed7fd26bd18c59391d Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Mon, 31 Jul 2023 21:21:09 +0100 Subject: [PATCH 03/74] Fixed docstring --- src/spikeinterface/core/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9b300e4787..817cb95d66 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -555,12 +555,12 @@ def dump_to_pickle( ): """ Dump recording extractor to a pickle file. - The extractor can be re-loaded with load_extractor_from_json(json_file) + The extractor can be re-loaded with load_extractor_from_pickle(pickle_file) Parameters ---------- file_path: str - Path of the json file + Path of the pickle file include_properties: bool If True, all properties are dumped relative_to: str, Path, or None From 8349b90593622af022fa6b80ede0bc021296e5d6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 14 Sep 2023 21:07:04 +0200 Subject: [PATCH 04/74] implement proof of concept merge_clusters/split_clusters from tridesclous and columbia codes. 
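A rough usage sketch of how the two new entry points could be chained after peak
detection and feature extraction (not part of the patch; peaks, peak_labels,
recording, features_folder, sparse_mask and neighbours_mask are assumed to come
from an earlier peak-pipeline run, and the keyword values are only illustrative):

    from spikeinterface.sortingcomponents.clustering.split import split_clusters
    from spikeinterface.sortingcomponents.clustering.merge import merge_clusters

    # local split: each cluster is re-clustered on its own channel neighbourhood,
    # optionally recursively until recursive_depth is reached
    post_split_labels, split_count = split_clusters(
        peak_labels,
        recording,
        features_folder,
        method="hdbscan_on_local_pca",
        method_kwargs=dict(
            feature_name="sparse_wfs",
            waveforms_sparse_mask=sparse_mask,
            neighbours_mask=neighbours_mask,
        ),
        recursive=True,
        recursive_depth=3,
        returns_split_count=True,
        n_jobs=4,
    )

    # pairwise merge of clusters whose templates are closer than radius_um,
    # using the "waveforms_lda" projection and a dip-test criterion
    post_merge_labels, peak_shifts = merge_clusters(
        peaks,
        post_split_labels,
        recording,
        features_folder,
        radius_um=70,
        method="waveforms_lda",
        method_kwargs=dict(feature_name="sparse_wfs", waveforms_sparse_mask=sparse_mask),
        n_jobs=4,
    )

    # the per-spike shifts returned by merge_clusters are reverse-applied to the peaks
    final_peaks = peaks.copy()
    final_peaks["sample_index"] -= peak_shifts

Both steps run through the process-pool helpers added to job_tools and honour the
new mp_context / max_threads_per_process job kwargs.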
--- src/spikeinterface/core/globals.py | 2 +- src/spikeinterface/core/job_tools.py | 60 ++- src/spikeinterface/core/node_pipeline.py | 7 +- .../sortingcomponents/clustering/clean.py | 45 ++ .../sortingcomponents/clustering/merge.py | 500 ++++++++++++++++++ .../sortingcomponents/clustering/split.py | 260 +++++++++ .../sortingcomponents/clustering/tools.py | 196 +++++++ .../sortingcomponents/tests/test_split.py | 12 + 8 files changed, 1074 insertions(+), 8 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/clustering/clean.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/merge.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/split.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/tools.py create mode 100644 src/spikeinterface/sortingcomponents/tests/test_split.py diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index e5581c7a67..d039206296 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -96,7 +96,7 @@ def is_set_global_dataset_folder(): ######################################## global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) +global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global global_job_kwargs_set global_job_kwargs_set = False diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..3e25d64983 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -380,10 +380,6 @@ def run(self): self.gather_func(res) else: n_jobs = min(self.n_jobs, len(all_chunks)) - ######## Do you want to limit the number of threads per process? - ######## It has to be done to speed up numpy a lot if multicores - ######## Otherwise, np.dot will be slow. 
How to do that, up to you - ######## This is just a suggestion, but here it adds a dependency # parallel with ProcessPoolExecutor( @@ -436,3 +432,59 @@ def function_wrapper(args): else: with threadpool_limits(limits=max_threads_per_process): return _func(segment_index, start_frame, end_frame, _worker_ctx) + + +# Here some utils + + +class MockFuture: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__(self, f, *args): + self.f = f + self.args = args + + def result(self): + return self.f(*self.args) + + +class MockPoolExecutor: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__( + self, + max_workers=None, + mp_context=None, + initializer=None, + initargs=None, + context=None, + ): + if initializer is not None: + initializer(*initargs) + self.map = map + self.imap = map + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def submit(self, f, *args): + return MockFuture(f, *args) + + +class MockQueue: + """Another helper class for turning off concurrency when debugging.""" + + def __init__(self): + self.q = [] + self.put = self.q.append + self.get = lambda: self.q.pop(0) + + +def get_poolexecutor(n_jobs): + if n_jobs == 1: + return MockPoolExecutor + else: + return ProcessPoolExecutor diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..cd858da1e1 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -436,6 +436,7 @@ def run_node_pipeline( job_name="pipeline", mp_context=None, gather_mode="memory", + gather_kwargs={}, squeeze_output=True, folder=None, names=None, @@ -452,7 +453,7 @@ def run_node_pipeline( if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) + gather_func = GatherToNpy(folder, names, **gather_kwargs) else: raise ValueError(f"wrong gather_mode : {gather_mode}") @@ -597,9 +598,9 @@ class GatherToNpy: * create the npy v1.0 header at the end with the correct shape and dtype """ - def __init__(self, folder, names, npy_header_size=1024): + def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) + self.folder.mkdir(parents=True, exist_ok=exist_ok) assert names is not None self.names = names self.npy_header_size = npy_header_size diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py new file mode 100644 index 0000000000..cbded0c49f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -0,0 +1,45 @@ +import numpy as np + +from .tools import FeaturesLoader, compute_template_from_sparse + +# This is work in progress ... 
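+# clean_clusters() below is only a scaffold for now: it loads the sparse waveforms
+# and the sparsity mask, counts the spikes of each cluster, builds per-cluster
+# templates and measures their peak amplitude according to peak_sign, but it
+# still returns the input labels unchanged.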
+ + +def clean_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + peak_sign="neg", +): + total_channels = recording.get_num_channels() + + if isinstance(features_dict_or_folder, dict): + features = features_dict_or_folder + else: + features = FeaturesLoader(features_dict_or_folder) + + clean_labels = peak_labels.copy() + + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + + count = np.zeros(n, dtype="int64") + for i, label in enumerate(labels_set): + count[i] = np.sum(peak_labels == label) + print(count) + + templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) + + if peak_sign == "both": + max_values = np.max(np.abs(templates), axis=(1, 2)) + elif peak_sign == "neg": + max_values = -np.min(templates, axis=(1, 2)) + elif peak_sign == "pos": + max_values = np.max(templates, axis=(1, 2)) + print(max_values) + + return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py new file mode 100644 index 0000000000..2e839ef0fc --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -0,0 +1,500 @@ +from pathlib import Path +from multiprocessing import get_context +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +import scipy.spatial +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from hdbscan import HDBSCAN + +import numpy as np +import networkx as nx + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + + +from .isocut5 import isocut5 + +from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse + + +def merge_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs, +): + """ + Merge cluster using differents methods. + + Parameters + ---------- + peaks: numpy.ndarray 1d + detected peaks (or a subset) + peak_labels: numpy.ndarray 1d + original label before merge + peak_labels.size == peaks.size + recording: Recording object + A recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method used + method_kwargs: dict + Option for the method. + Returns + ------- + merge_peak_labels: numpy.ndarray 1d + New vectors label after merges. + peak_shifts: numpy.ndarray 1d + A vector of sample shift to be reverse applied on original sample_index on peak detection + Negative shift means too early. + Posituve shift means too late. 
+ So the correction must be applied like this externaly: + final_peaks = peaks.copy() + final_peaks['sample_index'] -= peak_shifts + + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + + features = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set, pair_mask, pair_shift, pair_values = find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=radius_um, + method=method, + method_kwargs=method_kwargs, + **job_kwargs, + ) + + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + + group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) + + # apply final label and shift + merge_peak_labels = peak_labels.copy() + peak_shifts = np.zeros(peak_labels.size, dtype="int64") + for merge, shifts in zip(merges, group_shifts): + label0 = merge[0] + mask = np.in1d(peak_labels, merge) + merge_peak_labels[mask] = label0 + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + peak_shifts[peak_labels == label1] = shifts[l] + + return merge_peak_labels, peak_shifts + + +def resolve_final_shifts(labels_set, merges, pair_mask, pair_shift): + labels_set = list(labels_set) + + group_shifts = [] + for merge in merges: + shifts = np.zeros(len(merge), dtype="int64") + + label_inds = [labels_set.index(label) for label in merge] + + label0 = merge[0] + ind0 = label_inds[0] + + # First find relative shift to label0 (l=0) in the subgraph + local_pair_mask = pair_mask[label_inds, :][:, label_inds] + local_pair_shift = None + G = None + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + ind1 = label_inds[l] + if local_pair_mask[0, l]: + # easy case the pair label0<>label1 was existing + shift = pair_shift[ind0, ind1] + else: + # more complicated case need to find intermediate label and propagate the shift!! + if G is None: + # the the graph only once and only if needed + G = nx.from_numpy_array(local_pair_mask | local_pair_mask.T) + local_pair_shift = pair_shift[label_inds, :][:, label_inds] + local_pair_shift += local_pair_shift.T + + shift_chain = nx.shortest_path(G, source=l, target=0) + shift = 0 + for i in range(len(shift_chain) - 1): + shift += local_pair_shift[shift_chain[i + 1], shift_chain[i]] + shifts[l] = shift + + group_shifts.append(shifts) + + return group_shifts + + +def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"): + """ + Agglomerate merge pairs into final merge groups. + + The merges are ordered by label. 
+ + """ + + labels_set = np.array(labels_set) + + merges = [] + + graph = nx.from_numpy_matrix(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + + groups = list(nx.connected_components(graph)) + for group in groups: + if len(group) == 1: + continue + sub_graph = graph.subgraph(group) + # print(group, sub_graph) + cliques = list(nx.find_cliques(sub_graph)) + if len(cliques) == 1 and len(cliques[0]) == len(group): + # the sub graph is full connected: no ambiguity + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) + elif len(cliques) > 1: + # the subgraph is not fully connected + if connection_mode == "full": + # node merge + pass + elif connection_mode == "partial": + group = list(group) + # merges.append(labels_set[group]) + merges.append(group) + elif connection_mode == "clique": + raise NotImplementedError + else: + raise ValueError + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(sub_graph) + plt.show() + + DEBUG = True + # DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + # ensure ordered label + merges = [np.sort(merge) for merge in merges] + + return merges + + +def find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs + # n_jobs=1, + # mp_context="fork", + # max_threads_per_process=1, + # progress_bar=True, +): + """ + Searh some possible merge 2 by 2. + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # features_dict_or_folder = Path(features_dict_or_folder) + + # peaks = features_dict_or_folder['peaks'] + total_channels = recording.get_num_channels() + + # sparse_wfs = features['sparse_wfs'] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + pair_mask = np.triu(np.ones((n, n), dtype="bool")) & ~np.eye(n, dtype="bool") + pair_shift = np.zeros((n, n), dtype="int64") + pair_values = np.zeros((n, n), dtype="float64") + + # compute template (no shift at this step) + + templates = compute_template_from_sparse( + peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None + ) + + max_chans = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + channel_locs = recording.get_channel_locations() + template_locs = channel_locs[max_chans, :] + template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean") + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + progress_bar = job_kwargs["progress_bar"] + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=find_pair_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + jobs = [] + for ind0, ind1 in zip(indices0, indices1): + label0 = labels_set[ind0] + label1 = labels_set[ind1] + jobs.append(pool.submit(find_pair_function_wrapper, label0, label1)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"find_merge_pairs with {method}", total=len(jobs)) + else: + iterator = jobs + + for res in iterator: 
+ is_merge, label0, label1, shift, merge_value = res.result() + ind0 = labels_set.index(label0) + ind1 = labels_set.index(label1) + + pair_mask[ind0, ind1] = is_merge + if is_merge: + pair_shift[ind0, ind1] = shift + pair_values[ind0, ind1] = merge_value + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + return labels_set, pair_mask, pair_shift, pair_values + + +def find_pair_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = find_pair_method_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + + # if isinstance(features_dict_or_folder, dict): + # _ctx["features"] = features_dict_or_folder + # else: + # _ctx["features"] = FeaturesLoader(features_dict_or_folder) + + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def find_pair_function_wrapper(label0, label1): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( + label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_merge, label0, label1, shift, merge_value + + +class WaveformsLda: + name = "waveforms_lda" + + @staticmethod + def merge( + label0, + label1, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + feature_name="sparse_tsvd", + projection="centroid", + criteria="diptest", + threshold_diptest=0.5, + threshold_percentile=80.0, + num_shift=2, + ): + if num_shift > 0: + assert feature_name == "sparse_wfs" + sparse_wfs = features[feature_name] + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + if inds0.size <40 or inds1.size <40: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + + target_chans = np.intersect1d(target_chans0, target_chans1) + + inds = np.concatenate([inds0, inds1]) + labels = np.zeros(inds.size, dtype="int") + labels[inds0.size :] = 1 + wfs, out = aggregate_sparse_features(peaks, inds, sparse_wfs, waveforms_sparse_mask, target_chans) + wfs = wfs[~out] + labels = labels[~out] + + cut = np.searchsorted(labels, 1) + wfs0_ = wfs[:cut, :, :] + wfs1_ = wfs[cut:, :, :] + + template0_ = np.mean(wfs0_, axis=0) + template1_ = np.mean(wfs1_, axis=0) + num_samples = template0_.shape[0] + + template0 = template0_[num_shift : num_samples - num_shift, :] + + wfs0 = wfs0_[:, num_shift : num_samples - num_shift, :] + + # best shift strategy 1 = max cosine + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # norm = np.linalg.norm(template0.flatten()) * np.linalg.norm(template1.flatten()) + # value = np.sum(template0.flatten() * template1.flatten()) / norm + # values.append(value) + # best_shift = np.argmax(values) + + # best shift 
strategy 2 = min dist**2 + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # value = np.sum((template1 - template0)**2) + # values.append(value) + # best_shift = np.argmin(values) + + # best shift strategy 3 : average delta argmin between channels + channel_shift = np.argmax(np.abs(template1_), axis=0) - np.argmax(np.abs(template0_), axis=0) + mask = np.abs(channel_shift) <= num_shift + channel_shift = channel_shift[mask] + if channel_shift.size > 0: + best_shift = int(np.round(np.mean(channel_shift))) + num_shift + else: + best_shift = num_shift + + wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :] + template1 = template1_[best_shift : best_shift + template0.shape[0], :] + + if projection == "lda": + wfs_0_1 = np.concatenate([wfs0, wfs1], axis=0) + flat_wfs = wfs_0_1.reshape(wfs_0_1.shape[0], -1) + feat = LinearDiscriminantAnalysis(n_components=1).fit_transform(flat_wfs, labels) + feat = feat[:, 0] + feat0 = feat[:cut] + feat1 = feat[cut:] + + elif projection == "centroid": + vector_0_1 = template1 - template0 + vector_0_1 /= np.sum(vector_0_1**2) + feat0 = np.sum((wfs0 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat1 = np.sum((wfs1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + # feat = np.sum((wfs_0_1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat = np.concatenate([feat0, feat1], axis=0) + + else: + raise ValueError(f"bad projection {projection}") + + if criteria == "diptest": + dipscore, cutpoint = isocut5(feat) + is_merge = dipscore < threshold_diptest + merge_value = dipscore + elif criteria == "percentile": + l0 = np.percentile(feat0, threshold_percentile) + l1 = np.percentile(feat1, 100.0 - threshold_percentile) + is_merge = l0 >= l1 + merge_value = l0 - l1 + else: + raise ValueError(f"bad criteria {criteria}") + + if is_merge: + final_shift = best_shift - num_shift + else: + final_shift = 0 + + DEBUG = True + # DEBUG = False + + if DEBUG and is_merge: + # if DEBUG: + import matplotlib.pyplot as plt + + flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) + flatten_wfs1 = wfs1.swapaxes(1, 2).reshape(wfs1.shape[0], -1) + + fig, axs = plt.subplots(ncols=2) + ax = axs[0] + ax.plot(flatten_wfs0.T, color="C0", alpha=0.01) + ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) + m0 = np.mean(flatten_wfs0, axis=0) + m1 = np.mean(flatten_wfs1, axis=0) + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") + + ax.legend() + + bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) + + count0, _ = np.histogram(feat0, bins=bins) + count1, _ = np.histogram(feat1, bins=bins) + + ax = axs[1] + ax.plot(bins[:-1], count0, color="C0") + ax.plot(bins[:-1], count1, color="C1") + + ax.set_title(f"{dipscore:.4f} {is_merge}") + plt.show() + + + return is_merge, label0, label1, final_shift, merge_value + + +find_pair_method_list = [ + WaveformsLda, +] +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py new file mode 100644 index 0000000000..411d8c2116 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -0,0 +1,260 @@ +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + 
+from sklearn.decomposition import TruncatedSVD +from hdbscan import HDBSCAN + +import numpy as np + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader +from .isocut5 import isocut5 + + +def split_clusters( + peak_labels, + recording, + features_dict_or_folder, + method="hdbscan_on_local_pca", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + **job_kwargs, +): + """ + Run recusrsively or not in a multi process pool a local split method. + + Parameters + ---------- + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method name + method_kwargs: dict + The method option + recursive: bool Default False + Reccursive or not. + recursive_depth: None or int + If recursive=True, then this is the max split per spikes. + returns_split_count: bool + Optionally return the split count vector. Same size as labels. + + Returns + ------- + new_labels: numpy.ndarray + The labels of peaks after split. + split_count: numpy.ndarray + Optionally returned + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + + jobs = [] + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) + else: + iterator = jobs + + for res in iterator: + is_split, local_labels, peak_indices = res.result() + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + + split_count[peak_indices] += 1 + + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + if recursive_depth is not None: + # stop reccursivity when recursive_depth is reach + extra_ball = np.max(split_count[peak_indices]) < recursive_depth + else: + # reccurssive always + extra_ball = True + + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + if progress_bar: + iterator.total += 1 + + if returns_split_count: + return peak_labels, split_count + else: + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + features_dict_or_folder + 
_ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def split_function_wrapper(peak_indices): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices + + +class HdbscanOnLocalPca: + # @charlie : this is the equivalent of "herding_split()" in DART + # but simplified, flexible and renamed + + name = "hdbscan_on_local_pca" + + @staticmethod + def split( + peak_indices, + peaks, + features, + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + min_cluster_size=25, + min_samples=25, + n_pca_features=2, + ): + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + + # TODO fix this a better way, this when cluster have too few overlapping channels + minimum_channels = 2 + if target_channels.size < minimum_channels: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) + # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) + final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if clusterer == "hdbscan": + clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True) + clust.fit(final_features) + possible_labels = clust.labels_ + elif clusterer == "isocut5": + dipscore, cutpoint = isocut5(final_features[:, 0]) + possible_labels = np.zeros(final_features.shape[0]) + if dipscore > 1.5: + mask = final_features[:, 0] > cutpoint + if np.sum(mask) > min_cluster_size and np.sum(~mask): + possible_labels[mask] = 1 + else: + return False, None + else: + raise ValueError(f"wrong clusterer {clusterer}") + + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.get_cmap("tab10", len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fix, axs = plt.subplots(nrows=2) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + sl = slice(None, None, 10) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k]) + + ax = 
axs[1] + ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + +split_methods_list = [ + HdbscanOnLocalPca, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py new file mode 100644 index 0000000000..9a537ab8a8 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Any +import numpy as np + + +# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) + + +class FeaturesLoader: + """ + Feature can be computed in memory or in a folder contaning npy files. + + This class read the folder and behave like a dict of array lazily. + + Parameters + ---------- + feature_folder + + preload + + """ + + def __init__(self, feature_folder, preload=["peaks"]): + self.feature_folder = Path(feature_folder) + + self.file_feature = {} + self.loaded_features = {} + for file in self.feature_folder.glob("*.npy"): + name = file.stem + if name in preload: + self.loaded_features[name] = np.load(file) + else: + self.file_feature[name] = file + + def __getitem__(self, name): + if name in self.loaded_features: + return self.loaded_features[name] + else: + return np.load(self.file_feature[name], mmap_mode="r") + + @staticmethod + def from_dict_or_folder(features_dict_or_folder): + if isinstance(features_dict_or_folder, dict): + return features_dict_or_folder + else: + return FeaturesLoader(features_dict_or_folder) + + +def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, target_channels): + """ + Aggregate sparse features that have unaligned channels and realigned then on target_channels. + + This is usefull to aligned back peaks waveform or pca or tsvd when detected a differents channels. + + + Parameters + ---------- + peaks + + peak_indices + + sparse_feature + + sparse_mask + + target_channels + + Returns + ------- + aligned_features: numpy.array + Aligned features. shape is (local_peaks.size, sparse_feature.shape[1], target_channels.size) + dont_have_channels: numpy.array + Boolean vector to indicate spikes that do not have all target channels to be taken in account + shape is peak_indices.size + """ + local_peaks = peaks[peak_indices] + + aligned_features = np.zeros( + (local_peaks.size, sparse_feature.shape[1], target_channels.size), dtype=sparse_feature.dtype + ) + dont_have_channels = np.zeros(peak_indices.size, dtype=bool) + + for chan in np.unique(local_peaks["channel_index"]): + sparse_chans = np.flatnonzero(sparse_mask[chan, :]) + peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) + if np.all(np.in1d(target_channels, sparse_chans)): + # peaks feature channel have all target_channels + source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) + aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] + else: + # some channel are missing, peak are not removde + dont_have_channels[peak_inds] = True + + return aligned_features, dont_have_channels + + +def compute_template_from_sparse( + peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None +): + """ + Compute template average from single sparse waveforms buffer. 
+ + Parameters + ---------- + peaks + + labels + + labels_set + + sparse_waveforms + + sparse_mask + + total_channels + + peak_shifts + + Returns + ------- + templates: numpy.array + Templates shape : (len(labels_set), num_samples, total_channels) + """ + n = len(labels_set) + + templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) + + for i, label in enumerate(labels_set): + peak_indices = np.flatnonzero(labels == label) + + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(sparse_mask[local_chans, :], axis=0)) + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_waveforms, sparse_mask, target_channels + ) + + if peak_shifts is not None: + apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) + + templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) + + return templates + + +def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): + """ + Apply a shift a spike level to realign waveforms buffers. + + This is usefull to compute template after merge when to cluster are shifted. + + A negative shift need the waveforms to be moved toward the right because the trough was too early. + A positive shift need the waveforms to be moved toward the left because the trough was too late. + + Note the border sample are left as before without move. + + Parameters + ---------- + + waveforms + + peak_shifts + + inplace + + Returns + ------- + aligned_waveforms + + + """ + + print("apply_waveforms_shift") + + if inplace: + aligned_waveforms = waveforms + else: + aligned_waveforms = waveforms.copy() + + shift_set = np.unique(peak_shifts) + assert max(np.abs(shift_set)) < aligned_waveforms.shape[1] + + for shift in shift_set: + if shift == 0: + continue + mask = peak_shifts == shift + wfs = waveforms[mask] + + if shift > 0: + aligned_waveforms[mask, :-shift, :] = wfs[:, shift:, :] + else: + aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] + + print("apply_waveforms_shift DONE") + + return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py new file mode 100644 index 0000000000..ed5e756469 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -0,0 +1,12 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + + +def test_split(): + pass + + +if __name__ == "__main__": + test_split() From 9c4bba37b89012c4d016394916a04832df3109c7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 21:00:13 +0200 Subject: [PATCH 05/74] wip --- .../sortingcomponents/clustering/merge.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 2e839ef0fc..e2049d70bf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -157,7 +157,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" merges = [] - graph = nx.from_numpy_matrix(pair_mask | pair_mask.T) + graph = nx.from_numpy_array(pair_mask | pair_mask.T) # put real nodes names for debugging maps = dict(zip(np.arange(labels_set.size), labels_set)) graph = nx.relabel_nodes(graph, maps) @@ -196,8 +196,8 @@ def 
agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - DEBUG = True - # DEBUG = False + # DEBUG = True + DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -457,8 +457,8 @@ def merge( else: final_shift = 0 - DEBUG = True - # DEBUG = False + # DEBUG = True + DEBUG = False if DEBUG and is_merge: # if DEBUG: @@ -487,7 +487,15 @@ def merge( ax.plot(bins[:-1], count0, color="C0") ax.plot(bins[:-1], count1, color="C1") - ax.set_title(f"{dipscore:.4f} {is_merge}") + if criteria == "diptest": + ax.set_title(f"{dipscore:.4f} {is_merge}") + elif criteria == "percentile": + ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") + ax.axvline(l0, color="C0") + ax.axvline(l1, color="C1") + + + plt.show() From 2ba8928b785ed06f8a2f01b48ea632a4171ab926 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Sun, 24 Sep 2023 13:51:48 -0500 Subject: [PATCH 06/74] Fix unit ID matching in sortingview curation Refine the logic for matching unit IDs in the sortingview curation process. Instead of using a potentially ambiguous containment check, unit IDs are now split at the '-' character, ensuring accurate mapping between unit labels and merged unit IDs. Additionally, introduced a unit test to validate the improved behavior and guard against potential false positives in future changes. --- .../curation/sortingview_curation.py | 3 +- .../tests/test_sortingview_curation.py | 45 +++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 6adf9effd4..f595a67a3f 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -83,8 +83,9 @@ def apply_sortingview_curation( properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): labels_unit = [] + unit_id_parts = str(unit_id).split('-') for unit_label, labels in labels_dict.items(): - if unit_label in str(unit_id): + if unit_label in unit_id_parts: labels_unit.extend(labels) for label in labels_unit: properties[label][u_i] = True diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 9177cb5536..1b9e6f2800 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -1,8 +1,10 @@ import pytest from pathlib import Path import os - +import json +import numpy as np import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec from spikeinterface import set_global_tmp_folder from spikeinterface.postprocessing import ( @@ -17,9 +19,7 @@ cache_folder = pytest.global_test_folder / "curation" else: cache_folder = Path("cache_folder") / "curation" - parent_folder = Path(__file__).parent - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -111,6 +111,7 @@ def test_json_curation(): # from curation.json json_file = parent_folder / "sv-sorting-curation.json" sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + print(f"Sorting: {sorting.get_unit_ids()}") print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 @@ 
-130,9 +131,47 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 +def test_false_positive_curation(): + # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html + sampling_frequency = 30000. + duration = 20. + num_timepoints = int(sampling_frequency * duration) + num_units = 20 + num_spikes = 1000 + times0 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels0 = np.random.randint(1, num_units + 1, size=num_spikes) + times1 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels1 = np.random.randint(1, num_units + 1, size=num_spikes) + + sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) + print('Sorting: {}'.format(sorting.get_unit_ids())) + + # Test curation JSON: + test_json = { + "labelsByUnit": { + "1": ["accept"], + }, + "mergeGroups": [] + } + + json_path = "test_data.json" + with open(json_path, 'w') as f: + json.dump(test_json, f, indent=4) + + sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + accept_idx = np.where(sorting_curated.get_property("accept"))[0] + sorting_curated_ids = sorting_curated.get_unit_ids() + print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') + + # Check if unit_id 1 has received the "accept" label. + assert sorting_curated.get_unit_property(unit_id=1, key="accept") + # Check if unit_id "#10" has received the "accept" label. + # If so, test fails since only unit_id 1 received the "accept" label in test_json. + assert not sorting_curated.get_unit_property(unit_id=10, key="accept") if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() + test_false_positive_curation() From 45c69f52147edd406f293f731b7c7c687c700d29 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Sun, 24 Sep 2023 14:46:01 -0500 Subject: [PATCH 07/74] Add merge check --- .gitignore | 1 + .../tests/test_sortingview_curation.py | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 3ee3cb8867..7838213bed 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ test_folder/ # Mac OS .DS_Store +test_data.json diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 1b9e6f2800..c8a0788223 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -115,6 +115,7 @@ def test_json_curation(): print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 + print(sorting_curated_json.unit_ids) assert "#8-#9" in sorting_curated_json.unit_ids assert "accept" in sorting_curated_json.get_property_keys() assert "mua" in sorting_curated_json.get_property_keys() @@ -150,24 +151,29 @@ def test_false_positive_curation(): test_json = { "labelsByUnit": { "1": ["accept"], + "2": ["artifact"], + "12": ["artifact"] }, - "mergeGroups": [] + "mergeGroups": [[2,12]] } json_path = "test_data.json" with open(json_path, 'w') as f: json.dump(test_json, f, indent=4) - sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - accept_idx = np.where(sorting_curated.get_property("accept"))[0] - sorting_curated_ids = sorting_curated.get_unit_ids() + sorting_curated_json = 
apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] + sorting_curated_ids = sorting_curated_json.get_unit_ids() print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') # Check if unit_id 1 has received the "accept" label. - assert sorting_curated.get_unit_property(unit_id=1, key="accept") - # Check if unit_id "#10" has received the "accept" label. + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + # Check if unit_id 10 has received the "accept" label. # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated.get_unit_property(unit_id=10, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + print(sorting_curated_json.unit_ids) + # Merging unit_ids of dtype int creates a new unit id + assert 21 in sorting_curated_json.unit_ids if __name__ == "__main__": # generate_sortingview_curation_dataset() From ffaf06756b3884646785fd81bce2d123abaaff0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Sep 2023 20:09:34 +0000 Subject: [PATCH 08/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/sortingview_curation.py | 2 +- .../tests/test_sortingview_curation.py | 33 ++++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f595a67a3f..a5633fe165 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -83,7 +83,7 @@ def apply_sortingview_curation( properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): labels_unit = [] - unit_id_parts = str(unit_id).split('-') + unit_id_parts = str(unit_id).split("-") for unit_label, labels in labels_dict.items(): if unit_label in unit_id_parts: labels_unit.extend(labels) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index c8a0788223..a8944f0688 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -132,10 +132,11 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 + def test_false_positive_curation(): # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html - sampling_frequency = 30000. - duration = 20. 
+ sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_units = 20 num_spikes = 1000 @@ -145,36 +146,30 @@ def test_false_positive_curation(): labels1 = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) - print('Sorting: {}'.format(sorting.get_unit_ids())) + print("Sorting: {}".format(sorting.get_unit_ids())) # Test curation JSON: - test_json = { - "labelsByUnit": { - "1": ["accept"], - "2": ["artifact"], - "12": ["artifact"] - }, - "mergeGroups": [[2,12]] - } + test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} json_path = "test_data.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(test_json, f, indent=4) sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] sorting_curated_ids = sorting_curated_json.get_unit_ids() - print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') + print(f"Accepted unit IDs: {sorting_curated_ids[accept_idx]}") - # Check if unit_id 1 has received the "accept" label. - assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") - # Check if unit_id 10 has received the "accept" label. - # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + # Check if unit_id 1 has received the "accept" label. + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + # Check if unit_id 10 has received the "accept" label. + # If so, test fails since only unit_id 1 received the "accept" label in test_json. + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") print(sorting_curated_json.unit_ids) - # Merging unit_ids of dtype int creates a new unit id + # Merging unit_ids of dtype int creates a new unit id assert 21 in sorting_curated_json.unit_ids + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From 57bb3a734978d207f12733eb4c4807cb8e22c06f Mon Sep 17 00:00:00 2001 From: Windows Home Date: Tue, 26 Sep 2023 22:54:41 -0500 Subject: [PATCH 09/74] Implement more tests to ensure int and string unit IDs merging, inheriting labels, etc. 
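As a minimal illustration of the inheritance rule the new tests exercise (the unit
ids and labels below are only an example, not one of the files added here):

    curation_dict = {
        "labelsByUnit": {"1": ["mua"], "2": ["mua"], "5": ["accept"], "6": ["accept"]},
        "mergeGroups": [[1, 2], [5, 6]],
    }

After apply_sortingview_curation(), the unit obtained by merging 1 and 2 carries the
union of the labels of its members ("mua") and the one obtained from 5 and 6 carries
"accept". With integer unit ids the merged units receive fresh integer ids from
CurationSorting; with string unit ids the merged id is the members joined with "-"
(e.g. "#8-#9") and labels are matched against that merged id.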
--- .../curation/sortingview_curation.py | 49 +++-- .../tests/test_sortingview_curation.py | 195 +++++++++++++++--- 2 files changed, 202 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f595a67a3f..b7f0cab6a0 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -57,38 +57,52 @@ def apply_sortingview_curation( unit_ids_dtype = sorting.unit_ids.dtype # STEP 1: merge groups + labels_dict = sortingview_curation_dict["labelsByUnit"] if "mergeGroups" in sortingview_curation_dict and not skip_merge: merge_groups = sortingview_curation_dict["mergeGroups"] - for mg in merge_groups: + for merge_group in merge_groups: + # Store labels of units that are about to be merged + labels_to_inherit = [] + for unit in merge_group: + labels_to_inherit.extend(labels_dict.get(str(unit), [])) + labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates + if verbose: - print(f"Merging {mg}") + print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(mg) + new_unit_id = "-".join(merge_group) + curation_sorting.merge(merge_group, new_unit_id=new_unit_id) else: # in this case, the CurationSorting takes care of finding a new unused int - new_unit_id = None - curation_sorting.merge(mg, new_unit_id=new_unit_id) + curation_sorting.merge(merge_group, new_unit_id=None) + new_unit_id = curation_sorting.max_used_id # merged unit id + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. # For example, the first 3 units could be labeled as "accept". 
# In this case, the first 3 values of the property "accept" will be True, the rest False - labels_dict = sortingview_curation_dict["labelsByUnit"] - properties = {} - for _, labels in labels_dict.items(): - for label in labels: - if label not in properties: - properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + + # Initialize the properties dictionary + properties = {label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() for label in labels} + + # Populate the properties dictionary for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = [] - unit_id_parts = str(unit_id).split('-') - for unit_label, labels in labels_dict.items(): - if unit_label in unit_id_parts: - labels_unit.extend(labels) + labels_unit = set() + + # Check for exact match first + if str(unit_id) in labels_dict: + labels_unit.update(labels_dict[str(unit_id)]) + # If no exact match, check if unit_label is a substring of unit_id (for string unit ID merged unit) + else: + for unit_label, labels in labels_dict.items(): + if isinstance(unit_id, str) and unit_label in unit_id: + labels_unit.update(labels) for label in labels_unit: properties[label][u_i] = True + for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) @@ -104,5 +118,4 @@ def apply_sortingview_curation( units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) units_to_remove = np.unique(units_to_remove) curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index c8a0788223..48923aa15d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -3,6 +3,7 @@ import os import json import numpy as np + import spikeinterface as si import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec @@ -14,11 +15,11 @@ compute_spike_amplitudes, ) from spikeinterface.curation import apply_sortingview_curation - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "curation" else: cache_folder = Path("cache_folder") / "curation" + parent_folder = Path(__file__).parent ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -27,6 +28,7 @@ set_global_tmp_folder(cache_folder) + # this needs to be run only once def generate_sortingview_curation_dataset(): import spikeinterface.widgets as sw @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_gh_curation(): + """ + Test curation using GitHub URI. 
+ """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) - - # from GH # curated link: # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22} gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) - print(f"From GH: {sorting_curated_gh}") assert len(sorting_curated_gh.unit_ids) == 9 assert "#8-#9" in sorting_curated_gh.unit_ids @@ -75,9 +77,13 @@ def test_gh_curation(): assert len(sorting_curated_gh_mua.unit_ids) == 6 assert len(sorting_curated_gh_art_mua.unit_ids) == 5 + print("Test for GH passed!\n") @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): + """ + Test curation using SHA1 URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) @@ -93,7 +99,7 @@ def test_sha1_curation(): assert "accept" in sorting_curated_sha1.get_property_keys() assert "mua" in sorting_curated_sha1.get_property_keys() assert "artifact" in sorting_curated_sha1.get_property_keys() - + unit_ids = sorting_curated_sha1.unit_ids sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"]) sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"]) sorting_curated_sha1_art_mua = apply_sortingview_curation( @@ -103,19 +109,21 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_mua.unit_ids) == 6 assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 + print("Test for sha1 curation passed!\n") def test_json_curation(): + """ + Test curation using a JSON file. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) print(f"Sorting: {sorting.get_unit_ids()}") - print(f"From JSON: {sorting_curated_json}") + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) assert len(sorting_curated_json.unit_ids) == 9 - print(sorting_curated_json.unit_ids) assert "#8-#9" in sorting_curated_json.unit_ids assert "accept" in sorting_curated_json.get_property_keys() assert "mua" in sorting_curated_json.get_property_keys() @@ -131,20 +139,23 @@ def test_json_curation(): assert len(sorting_curated_json_accepted.unit_ids) == 3 assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 + + print("Test for json curation passed!\n") def test_false_positive_curation(): + """ + Test curation for false positives. + """ # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html sampling_frequency = 30000. duration = 20. 
num_timepoints = int(sampling_frequency * duration) num_units = 20 num_spikes = 1000 - times0 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels0 = np.random.randint(1, num_units + 1, size=num_spikes) - times1 = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels1 = np.random.randint(1, num_units + 1, size=num_spikes) + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, num_units + 1, size=num_spikes) - sorting = se.NumpySorting.from_times_labels([times0, times1], [labels0, labels1], sampling_frequency) + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print('Sorting: {}'.format(sorting.get_unit_ids())) # Test curation JSON: @@ -161,23 +172,159 @@ def test_false_positive_curation(): with open(json_path, 'w') as f: json.dump(test_json, f, indent=4) + # Apply curation sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - accept_idx = np.where(sorting_curated_json.get_property("accept"))[0] - sorting_curated_ids = sorting_curated_json.get_unit_ids() - print(f'Accepted unit IDs: {sorting_curated_ids[accept_idx]}') - - # Check if unit_id 1 has received the "accept" label. - assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") - # Check if unit_id 10 has received the "accept" label. - # If so, test fails since only unit_id 1 received the "accept" label in test_json. - assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") - print(sorting_curated_json.unit_ids) - # Merging unit_ids of dtype int creates a new unit id + print('Curated:', sorting_curated_json.get_unit_ids()) + + # Assertions + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") assert 21 in sorting_curated_json.unit_ids + print("False positive test for integer unit IDs passed!\n") + +def test_label_inheritance_int(): + """ + Test curation for label inheritance for integer unit IDs. + """ + # Setup + sampling_frequency = 30000. + duration = 20. 
+ num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, 8, size=num_spikes) # 7 units: 1 to 7 + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + + # Create a curation JSON with labels and merge groups + curation_dict = { + "labelsByUnit": { + "1": ["mua"], + "2": ["mua"], + "3": ["reject"], + "4": ["noise"], + "5": ["accept"], + "6": ["accept"], + "7": ["accept"] + }, + "mergeGroups": [[1, 2], [3, 4], [5, 6]] + } + + json_path = "test_curation_int.json" + with open(json_path, 'w') as f: + json.dump(curation_dict, f, indent=4) + + # Apply curation + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path) + + # Assertions for merged units + print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 + assert not sorting_merge.get_unit_property(unit_id=8, key="reject") + assert not sorting_merge.get_unit_property(unit_id=8, key="noise") + assert not sorting_merge.get_unit_property(unit_id=8, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4 + assert sorting_merge.get_unit_property(unit_id=9, key="reject") + assert sorting_merge.get_unit_property(unit_id=9, key="noise") + assert not sorting_merge.get_unit_property(unit_id=9, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6 + assert not sorting_merge.get_unit_property(unit_id=10, key="reject") + assert not sorting_merge.get_unit_property(unit_id=10, key="noise") + assert sorting_merge.get_unit_property(unit_id=10, key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert 9 not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert 8 not in sorting_include_accept.get_unit_ids() + assert 9 not in sorting_include_accept.get_unit_ids() + assert 10 in sorting_include_accept.get_unit_ids() + + print("Test for integer unit IDs passed!\n") + + +def test_label_inheritance_str(): + """ + Test curation for label inheritance for string unit IDs. + """ + sampling_frequency = 30000. + duration = 20. 
+ num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.choice(['a', 'b', 'c', 'd', 'e', 'f', 'g'], size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + print(f"Sorting: {sorting.get_unit_ids()}") + # Create a curation JSON with labels and merge groups + curation_dict = { + "labelsByUnit": { + "a": ["mua"], + "b": ["mua"], + "c": ["reject"], + "d": ["noise"], + "e": ["accept"], + "f": ["accept"], + "g": ["accept"] + }, + "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]] + } + + json_path = "test_curation_str.json" + with open(json_path, 'w') as f: + json.dump(curation_dict, f, indent=4) + + # Check label inheritance for merged units + merged_id_1 = "a-b" + merged_id_2 = "c-d" + merged_id_3 = "e-f" + # Apply curation + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + + # Assertions for merged units + print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua") + assert sorting_merge.get_unit_property(unit_id="c-d", key="reject") + assert sorting_merge.get_unit_property(unit_id="c-d", key="noise") + assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") + assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert "c-d" not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert "a-b" not in sorting_include_accept.get_unit_ids() + assert "c-d" not in sorting_include_accept.get_unit_ids() + assert "e-f" in sorting_include_accept.get_unit_ids() + + print("Test for string unit IDs passed!\n") + + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() test_false_positive_curation() + test_label_inheritance_int() + test_label_inheritance_str() \ No newline at end of file From a8e07a71d8306550a20a6a611222fb76190d3178 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 04:01:49 +0000 Subject: [PATCH 10/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/sortingview_curation.py | 9 ++++-- .../tests/test_sortingview_curation.py | 31 ++++++++++--------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 7ae8e41030..f83ff3352b 100644 --- 
a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -76,7 +76,7 @@ def apply_sortingview_curation( else: # in this case, the CurationSorting takes care of finding a new unused int curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id + new_unit_id = curation_sorting.max_used_id # merged unit id labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels @@ -85,8 +85,11 @@ def apply_sortingview_curation( # In this case, the first 3 values of the property "accept" will be True, the rest False # Initialize the properties dictionary - properties = {label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() for label in labels} + properties = { + label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() + for label in labels + } # Populate the properties dictionary for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 958df6acb5..cfc15013a3 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -15,6 +15,7 @@ compute_spike_amplitudes, ) from spikeinterface.curation import apply_sortingview_curation + if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "curation" else: @@ -28,7 +29,6 @@ set_global_tmp_folder(cache_folder) - # this needs to be run only once def generate_sortingview_curation_dataset(): import spikeinterface.widgets as sw @@ -79,6 +79,7 @@ def test_gh_curation(): print("Test for GH passed!\n") + @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): """ @@ -111,6 +112,7 @@ def test_sha1_curation(): print("Test for sha1 curation passed!\n") + def test_json_curation(): """ Test curation using a JSON file. @@ -157,7 +159,7 @@ def test_false_positive_curation(): labels = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print('Sorting: {}'.format(sorting.get_unit_ids())) + print("Sorting: {}".format(sorting.get_unit_ids())) # Test curation JSON: test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} @@ -168,7 +170,7 @@ def test_false_positive_curation(): # Apply curation sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) - print('Curated:', sorting_curated_json.get_unit_ids()) + print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") @@ -177,13 +179,14 @@ def test_false_positive_curation(): print("False positive test for integer unit IDs passed!\n") + def test_label_inheritance_int(): """ Test curation for label inheritance for integer unit IDs. """ # Setup - sampling_frequency = 30000. - duration = 20. 
+ sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) @@ -200,13 +203,13 @@ def test_label_inheritance_int(): "4": ["noise"], "5": ["accept"], "6": ["accept"], - "7": ["accept"] + "7": ["accept"], }, - "mergeGroups": [[1, 2], [3, 4], [5, 6]] + "mergeGroups": [[1, 2], [3, 4], [5, 6]], } json_path = "test_curation_int.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(curation_dict, f, indent=4) # Apply curation @@ -248,12 +251,12 @@ def test_label_inheritance_str(): """ Test curation for label inheritance for string unit IDs. """ - sampling_frequency = 30000. - duration = 20. + sampling_frequency = 30000.0 + duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels = np.random.choice(['a', 'b', 'c', 'd', 'e', 'f', 'g'], size=num_spikes) + labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print(f"Sorting: {sorting.get_unit_ids()}") @@ -266,13 +269,13 @@ def test_label_inheritance_str(): "d": ["noise"], "e": ["accept"], "f": ["accept"], - "g": ["accept"] + "g": ["accept"], }, - "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]] + "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]], } json_path = "test_curation_str.json" - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(curation_dict, f, indent=4) # Check label inheritance for merged units From fb82e029be652fa33b69367d9d97f9c7a465914e Mon Sep 17 00:00:00 2001 From: Robin Kim <31869753+rkim48@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:16:37 -0500 Subject: [PATCH 11/74] Apply suggestions from code review Remove print('success') statements Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../curation/tests/test_sortingview_curation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index cfc15013a3..79cea3d010 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -77,7 +77,6 @@ def test_gh_curation(): assert len(sorting_curated_gh_mua.unit_ids) == 6 assert len(sorting_curated_gh_art_mua.unit_ids) == 5 - print("Test for GH passed!\n") @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") @@ -110,7 +109,6 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_mua.unit_ids) == 6 assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 - print("Test for sha1 curation passed!\n") def test_json_curation(): @@ -244,7 +242,6 @@ def test_label_inheritance_int(): assert 9 not in sorting_include_accept.get_unit_ids() assert 10 in sorting_include_accept.get_unit_ids() - print("Test for integer unit IDs passed!\n") def test_label_inheritance_str(): @@ -314,7 +311,6 @@ def test_label_inheritance_str(): assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() - print("Test for string unit IDs passed!\n") if __name__ == "__main__": From 776520bb100986bd90653d9b8eeba77eb0cc16aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> 
Date: Wed, 27 Sep 2023 15:16:55 +0000 Subject: [PATCH 12/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/tests/test_sortingview_curation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 79cea3d010..71912d7793 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -78,7 +78,6 @@ def test_gh_curation(): assert len(sorting_curated_gh_art_mua.unit_ids) == 5 - @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): """ @@ -110,7 +109,6 @@ def test_sha1_curation(): assert len(sorting_curated_sha1_art_mua.unit_ids) == 5 - def test_json_curation(): """ Test curation using a JSON file. @@ -243,7 +241,6 @@ def test_label_inheritance_int(): assert 10 in sorting_include_accept.get_unit_ids() - def test_label_inheritance_str(): """ Test curation for label inheritance for string unit IDs. @@ -312,7 +309,6 @@ def test_label_inheritance_str(): assert "e-f" in sorting_include_accept.get_unit_ids() - if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From a85b4a8d666311325e74feaf05e47656048355ea Mon Sep 17 00:00:00 2001 From: Windows Home Date: Thu, 28 Sep 2023 09:39:22 -0500 Subject: [PATCH 13/74] Simplify label assignment logic and add test.json files to tests directory --- .../curation/sortingview_curation.py | 19 ++--- .../sv-sorting-curation-false-positive.json | 19 +++++ .../tests/sv-sorting-curation-int.json | 39 ++++++++++ .../tests/sv-sorting-curation-str.json | 39 ++++++++++ .../tests/test_sortingview_curation.py | 71 +++---------------- 5 files changed, 114 insertions(+), 73 deletions(-) create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-int.json create mode 100644 src/spikeinterface/curation/tests/sv-sorting-curation-str.json diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f83ff3352b..7a573c38c4 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -77,7 +77,7 @@ def apply_sortingview_curation( # in this case, the CurationSorting takes care of finding a new unused int curation_sorting.merge(merge_group, new_unit_id=None) new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels # In sortingview, a unit is not required to have all labels. 
@@ -92,19 +92,12 @@ def apply_sortingview_curation( } # Populate the properties dictionary - for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = set() - + for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): + unit_id_str = str(unit_id) # Check for exact match first - if str(unit_id) in labels_dict: - labels_unit.update(labels_dict[str(unit_id)]) - # If no exact match, check if unit_label is a substring of unit_id (for string unit ID merged unit) - else: - for unit_label, labels in labels_dict.items(): - if isinstance(unit_id, str) and unit_label in unit_id: - labels_unit.update(labels) - for label in labels_unit: - properties[label][u_i] = True + if unit_id_str in labels_dict: + for label in labels_dict[unit_id_str]: + properties[label][unit_index] = True for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json new file mode 100644 index 0000000000..5c29328363 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -0,0 +1,19 @@ +{ + "labelsByUnit": { + "1": [ + "accept" + ], + "2": [ + "artifact" + ], + "12": [ + "artifact" + ] + }, + "mergeGroups": [ + [ + 2, + 12 + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json new file mode 100644 index 0000000000..486a51a583 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "1": [ + "mua" + ], + "2": [ + "mua" + ], + "3": [ + "reject" + ], + "4": [ + "noise" + ], + "5": [ + "accept" + ], + "6": [ + "accept" + ], + "7": [ + "accept" + ] + }, + "mergeGroups": [ + [ + 1, + 2 + ], + [ + 3, + 4 + ], + [ + 5, + 6 + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json new file mode 100644 index 0000000000..b2ab1d5849 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "a": [ + "mua" + ], + "b": [ + "mua" + ], + "c": [ + "reject" + ], + "d": [ + "noise" + ], + "e": [ + "accept" + ], + "f": [ + "accept" + ], + "g": [ + "accept" + ] + }, + "mergeGroups": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ], + [ + "e", + "f" + ] + ] +} \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 71912d7793..1579c9f03b 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -138,8 +138,6 @@ def test_json_curation(): assert len(sorting_curated_json_mua.unit_ids) == 6 assert len(sorting_curated_json_mua1.unit_ids) == 5 - print("Test for json curation passed!\n") - def test_false_positive_curation(): """ @@ -157,15 +155,8 @@ def test_false_positive_curation(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print("Sorting: {}".format(sorting.get_unit_ids())) - # Test curation JSON: - test_json = {"labelsByUnit": {"1": ["accept"], "2": ["artifact"], "12": ["artifact"]}, "mergeGroups": [[2, 12]]} - - json_path = 
"test_data.json" - with open(json_path, "w") as f: - json.dump(test_json, f, indent=4) - - # Apply curation - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + json_file = parent_folder / "sv-sorting-curation-false-positive.json" + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions @@ -173,8 +164,6 @@ def test_false_positive_curation(): assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") assert 21 in sorting_curated_json.unit_ids - print("False positive test for integer unit IDs passed!\n") - def test_label_inheritance_int(): """ @@ -190,26 +179,8 @@ def test_label_inheritance_int(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - # Create a curation JSON with labels and merge groups - curation_dict = { - "labelsByUnit": { - "1": ["mua"], - "2": ["mua"], - "3": ["reject"], - "4": ["noise"], - "5": ["accept"], - "6": ["accept"], - "7": ["accept"], - }, - "mergeGroups": [[1, 2], [3, 4], [5, 6]], - } - - json_path = "test_curation_int.json" - with open(json_path, "w") as f: - json.dump(curation_dict, f, indent=4) - - # Apply curation - sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path) + json_file = parent_folder / "sv-sorting-curation-int.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) # Assertions for merged units print(f"Merge only: {sorting_merge.get_unit_ids()}") @@ -229,12 +200,12 @@ def test_label_inheritance_int(): assert sorting_merge.get_unit_property(unit_id=10, key="accept") # Assertions for exclude_labels - sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert 9 not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels - sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert 8 not in sorting_include_accept.get_unit_ids() assert 9 not in sorting_include_accept.get_unit_ids() @@ -254,30 +225,10 @@ def test_label_inheritance_str(): sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) print(f"Sorting: {sorting.get_unit_ids()}") - # Create a curation JSON with labels and merge groups - curation_dict = { - "labelsByUnit": { - "a": ["mua"], - "b": ["mua"], - "c": ["reject"], - "d": ["noise"], - "e": ["accept"], - "f": ["accept"], - "g": ["accept"], - }, - "mergeGroups": [["a", "b"], ["c", "d"], ["e", "f"]], - } - - json_path = "test_curation_str.json" - with open(json_path, "w") as f: - json.dump(curation_dict, f, indent=4) - - # Check label inheritance for merged units - merged_id_1 = "a-b" - merged_id_2 = "c-d" - merged_id_3 = "e-f" + # Apply curation - sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_path, verbose=True) + json_file = parent_folder / "sv-sorting-curation-str.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) # Assertions for merged units print(f"Merge only: {sorting_merge.get_unit_ids()}") @@ -297,12 +248,12 @@ def 
test_label_inheritance_str(): assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") # Assertions for exclude_labels - sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_path, exclude_labels=["noise"]) + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert "c-d" not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels - sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_path, include_labels=["accept"]) + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert "a-b" not in sorting_include_accept.get_unit_ids() assert "c-d" not in sorting_include_accept.get_unit_ids() From 54d40eb2a0cc4468100fd8a058cb8a6b8354fd67 Mon Sep 17 00:00:00 2001 From: Windows Home Date: Thu, 28 Sep 2023 09:52:29 -0500 Subject: [PATCH 14/74] Comment out print statements --- .../tests/test_sortingview_curation.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 1579c9f03b..a620cb8db1 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -91,7 +91,7 @@ def test_sha1_curation(): # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22} sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22" sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) - print(f"From SHA: {sorting_curated_sha1}") + # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 assert "#8-#9" in sorting_curated_sha1.unit_ids @@ -118,7 +118,7 @@ def test_json_curation(): # from curation.json json_file = parent_folder / "sv-sorting-curation.json" - print(f"Sorting: {sorting.get_unit_ids()}") + # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) assert len(sorting_curated_json.unit_ids) == 9 @@ -153,11 +153,11 @@ def test_false_positive_curation(): labels = np.random.randint(1, num_units + 1, size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print("Sorting: {}".format(sorting.get_unit_ids())) + # print("Sorting: {}".format(sorting.get_unit_ids())) json_file = parent_folder / "sv-sorting-curation-false-positive.json" sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) - print("Curated:", sorting_curated_json.get_unit_ids()) + # print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") @@ -183,7 +183,7 @@ def test_label_inheritance_int(): sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) # Assertions for merged units - print(f"Merge only: {sorting_merge.get_unit_ids()}") + # print(f"Merge only: {sorting_merge.get_unit_ids()}") assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 assert not 
sorting_merge.get_unit_property(unit_id=8, key="reject") assert not sorting_merge.get_unit_property(unit_id=8, key="noise") @@ -201,12 +201,12 @@ def test_label_inheritance_int(): # Assertions for exclude_labels sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert 9 not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert 8 not in sorting_include_accept.get_unit_ids() assert 9 not in sorting_include_accept.get_unit_ids() assert 10 in sorting_include_accept.get_unit_ids() @@ -224,14 +224,14 @@ def test_label_inheritance_str(): labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) - print(f"Sorting: {sorting.get_unit_ids()}") + # print(f"Sorting: {sorting.get_unit_ids()}") # Apply curation json_file = parent_folder / "sv-sorting-curation-str.json" sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) # Assertions for merged units - print(f"Merge only: {sorting_merge.get_unit_ids()}") + # print(f"Merge only: {sorting_merge.get_unit_ids()}") assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") @@ -249,17 +249,16 @@ def test_label_inheritance_str(): # Assertions for exclude_labels sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") assert "c-d" not in sorting_exclude_noise.get_unit_ids() # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") assert "a-b" not in sorting_include_accept.get_unit_ids() assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() - if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From f1b7bfe668ac8ff0581f252241edfb004577551d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 14:53:07 +0000 Subject: [PATCH 15/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/tests/sv-sorting-curation-false-positive.json | 2 +- src/spikeinterface/curation/tests/sv-sorting-curation-int.json | 2 +- src/spikeinterface/curation/tests/sv-sorting-curation-str.json | 2 +- src/spikeinterface/curation/tests/test_sortingview_curation.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json index 5c29328363..48881388bb 100644 --- 
a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -16,4 +16,4 @@ 12 ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json index 486a51a583..2047c514ce 100644 --- a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -36,4 +36,4 @@ 6 ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json index b2ab1d5849..2585b5cc50 100644 --- a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -36,4 +36,4 @@ "f" ] ] -} \ No newline at end of file +} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index a620cb8db1..22085f2f77 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -259,6 +259,7 @@ def test_label_inheritance_str(): assert "c-d" not in sorting_include_accept.get_unit_ids() assert "e-f" in sorting_include_accept.get_unit_ids() + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() From 8c633aceb84ff8e19e98949e9a9e366da3277053 Mon Sep 17 00:00:00 2001 From: Matthias H Hennig Date: Fri, 29 Sep 2023 11:54:02 +0100 Subject: [PATCH 16/74] Pip install into working directory for containers Apptainer fails to pip install into the system directory (not writable by default, no space when writable), and the --user flag ensures packages are installed in a writable location. Note not tested with docker. 
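For illustration only (not part of this patch), the writable-location issue can be reasoned
about from Python itself; the helper below is a hypothetical sketch of when a "--user"
fallback is needed, and only approximates what pip checks internally:

    import os
    import site
    import sys

    def build_pip_install_cmd(packages):
        # hypothetical helper: append "--user" when the system site-packages
        # directory is not writable (the default in many Apptainer images)
        cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "--no-input"]
        if not os.access(site.getsitepackages()[0], os.W_OK):
            cmd.append("--user")
        return cmd + list(packages)

    print(" ".join(build_pip_install_cmd(["spikeinterface[full]"])))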
--- src/spikeinterface/sorters/runsorter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 6e6ccc0358..f6501ef40f 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -514,19 +514,19 @@ def run_sorter_container( res_output = container_client.run_command(cmd) cmd = f"cp -r {si_dev_path_unix} {si_source_folder}" res_output = container_client.run_command(cmd) - cmd = f"pip install {si_source_folder}/spikeinterface[full]" + cmd = f"pip install --user {si_source_folder}/spikeinterface[full]" else: si_source = "remote repository" - cmd = "pip install --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" + cmd = "pip install --user --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" if verbose: print(f"Installing dev spikeinterface from {si_source}") res_output = container_client.run_command(cmd) - cmd = "pip install --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" + cmd = "pip install --user --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" res_output = container_client.run_command(cmd) else: if verbose: print(f"Installing spikeinterface=={si_version} in {container_image}") - cmd = f"pip install --upgrade --no-input spikeinterface[full]=={si_version}" + cmd = f"pip install --user --upgrade --no-input spikeinterface[full]=={si_version}" res_output = container_client.run_command(cmd) else: # TODO version checking @@ -540,7 +540,7 @@ def run_sorter_container( if extra_requirements: if verbose: print(f"Installing extra requirements: {extra_requirements}") - cmd = f"pip install --upgrade --no-input {' '.join(extra_requirements)}" + cmd = f"pip install --user --upgrade --no-input {' '.join(extra_requirements)}" res_output = container_client.run_command(cmd) # run sorter on folder From 4f2a50d7d1e0414bdf3bf2bdc3b9d35b12a900e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Sep 2023 11:18:26 +0000 Subject: [PATCH 17/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/mdaextractors.py | 2 +- .../benchmark/benchmark_motion_estimation.py | 6 ++---- .../benchmark/benchmark_motion_interpolation.py | 8 ++++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 98378551f5..b863e338fa 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -220,7 +220,7 @@ def write_sorting(sorting, save_path, write_primary_channels=False): times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) # unit id may not be numeric - if unit_id.dtype.kind in 'biufc': + if unit_id.dtype.kind in "biufc": labels_list.append(np.ones(times.shape) * unit_id) else: labels_list.append(np.ones(times.shape) * unit_id_i) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index c505676c05..abf40b2da6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ 
b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -584,13 +584,13 @@ def plot_motions_several_benchmarks(benchmarks): _simpleaxis(ax) -def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): +def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) for count, benchmark in enumerate(benchmarks): color = colors[count] if colors is not None else None - + if detailed: bottom = 0 i = 0 @@ -606,8 +606,6 @@ def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=No else: total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) ax.bar([count], [total_run_time], color=color, edgecolor="black") - - # ax.legend() ax.set_ylabel("speed (s)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 8e5afb2e8e..b28b29f17c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -193,11 +193,15 @@ def run_sorters(self, skip_already_done=True): recording = self.recordings[case["recording"]] output_folder = self.folder / f"tmp_sortings_{label}" if output_folder.exists() and skip_already_done: - print('already done') + print("already done") sorting = read_sorter_folder(output_folder) else: sorting = run_sorter( - sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder + sorter_name, + recording, + output_folder, + **sorter_params, + delete_output_folder=self.delete_output_folder, ) self.sortings[label] = sorting From 06089b8c0ed74c37c89d0d6ed2684e4c57668322 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 29 Sep 2023 15:07:28 +0200 Subject: [PATCH 18/74] Patch for sharedmem --- .../sorters/internal/spyking_circus2.py | 3 +-- .../clustering/clustering_tools.py | 5 +++-- .../clustering/random_projections.py | 15 ++++++++++----- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 710c4f76f4..a0a4d0823c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -118,8 +118,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clustering_folder.exists(): shutil.rmtree(clustering_folder) - sorting = sorting.save(folder=clustering_folder) - ## We get the templates our of such a clustering waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) @@ -131,6 +129,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): mode = "memory" waveforms_folder = None else: + sorting = sorting.save(folder=clustering_folder) mode = "folder" waveforms_folder = sorter_output_folder / "waveforms" diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 28a1a63065..1a8332ad6d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -574,6 +574,8 @@ def remove_duplicates_via_matching( if tmp_folder is None: tmp_folder = get_global_tmp_folder() + tmp_folder.mkdir(parents=True, exist_ok=True) + 
tmp_filename = tmp_folder / "tmp.raw" f = open(tmp_filename, "wb") @@ -583,8 +585,8 @@ def remove_duplicates_via_matching( f.close() recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") - recording.annotate(is_filtered=True) recording = recording.set_probe(waveform_extractor.recording.get_probe()) + recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) half_marging = margin // 2 @@ -608,7 +610,6 @@ def remove_duplicates_via_matching( t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 864548e7d4..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -181,17 +181,20 @@ def sigmoid(x, L, x0, k, b): else: tmp_folder = Path(params["tmp_folder"]) + tmp_folder.mkdir(parents=True, exist_ok=True) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + if params["shared_memory"]: waveform_folder = None mode = "memory" else: waveform_folder = tmp_folder / "waveforms" mode = "folder" + sorting = sorting.save(folder=sorting_folder) - sorting_folder = tmp_folder / "sorting" - unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) - sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) - sorting = sorting.save(folder=sorting_folder) we = extract_waveforms( recording, sorting, @@ -219,12 +222,14 @@ def sigmoid(x, L, x0, k, b): we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) + del we, sorting + if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) else: if not params["shared_memory"]: shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") + shutil.rmtree(tmp_folder / "sorting") if verbose: print("We kept %d non-duplicated clusters..." 
% len(labels)) From fa725fcd24c26ca3a55605a051c3527fb23cc35b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:27:53 -0400 Subject: [PATCH 19/74] add keyword arguments --- doc/how_to/load_matlab_data.rst | 4 +- doc/modules/curation.rst | 36 +++++----- doc/modules/exporters.rst | 19 +++-- doc/modules/extractors.rst | 35 +++++++--- doc/modules/motion_correction.rst | 44 ++++++------ doc/modules/postprocessing.rst | 10 +-- doc/modules/preprocessing.rst | 70 ++++++++++--------- doc/modules/qualitymetrics.rst | 8 +-- doc/modules/qualitymetrics/amplitude_cv.rst | 2 +- .../qualitymetrics/amplitude_median.rst | 2 +- doc/modules/qualitymetrics/d_prime.rst | 2 +- doc/modules/qualitymetrics/drift.rst | 6 +- doc/modules/qualitymetrics/firing_range.rst | 2 +- doc/modules/qualitymetrics/firing_rate.rst | 2 +- .../qualitymetrics/isolation_distance.rst | 10 +++ doc/modules/qualitymetrics/l_ratio.rst | 11 +++ doc/modules/qualitymetrics/presence_ratio.rst | 2 +- .../qualitymetrics/silhouette_score.rst | 10 +++ .../qualitymetrics/sliding_rp_violations.rst | 2 +- doc/modules/qualitymetrics/snr.rst | 3 +- doc/modules/qualitymetrics/synchrony.rst | 2 +- doc/modules/sorters.rst | 42 +++++------ doc/modules/sortingcomponents.rst | 23 +++--- doc/modules/widgets.rst | 10 +-- 24 files changed, 203 insertions(+), 154 deletions(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index e12d83810a..54a66c0890 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -54,7 +54,7 @@ Use the following Python script to load the binary data into SpikeInterface: dtype = "float64" # MATLAB's double corresponds to Python's float64 # Load data using SpikeInterface - recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype) # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data @@ -86,7 +86,7 @@ If your data in MATLAB is stored as :code:`int16`, and you know the gain and off gain_to_uV = 0.195 # Adjust according to your MATLAB dataset offset_to_uV = 0 # Adjust according to your MATLAB dataset - recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype_int, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 6101b81552..23e9e20d96 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -24,21 +24,21 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa from spikeinterface.curation import CurationSorting - sorting = run_sorter('kilosort2', recording) + sorting = run_sorter(sorter_name='kilosort2', recording=recording) - cs = CurationSorting(sorting) + cs = CurationSorting(parent_sorting=sorting) # make a first merge - cs.merge(['#1', '#5', '#15']) + cs.merge(units_to_merge=['#1', '#5', '#15']) # make a second merge - cs.merge(['#11', '#21']) + cs.merge(units_to_merge=['#11', '#21']) # make a split split_index = ... # some criteria on spikes - cs.split('#20', split_index) + cs.split(split_unit_id='#20', indices_list=split_index) - # here the final clean sorting + # here is the final clean sorting clean_sorting = cs.sorting @@ -60,12 +60,12 @@ merges. 
Therefore, it has many parameters and options.

     from spikeinterface.curation import MergeUnitsSorting, get_potential_auto_merge

-    sorting = run_sorter('kilosort', recording)
+    sorting = run_sorter(sorter_name='kilosort', recording=recording)

-    we = extract_waveforms(recording, sorting, folder='wf_folder')
+    we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder')

     # merges is a list of lists, with unit_ids to be merged.
-    merges = get_potential_auto_merge(we, minimum_spikes=1000, maximum_distance_um=150.,
+    merges = get_potential_auto_merge(waveform_extractor=we, minimum_spikes=1000, maximum_distance_um=150.,
                                       peak_sign="neg", bin_ms=0.25, window_ms=100.,
                                       corr_diff_thresh=0.16, template_diff_thresh=0.25,
                                       censored_period_ms=0., refractory_period_ms=1.0,
@@ -73,7 +73,7 @@ merges. Therefore, it has many parameters and options.
                                       firing_contamination_balance=1.5)

     # here we apply the merges
-    clean_sorting = MergeUnitsSorting(sorting, merges)
+    clean_sorting = MergeUnitsSorting(parent_sorting=sorting, units_to_merge=merges)


 Manual curation with sorting view
@@ -98,24 +98,24 @@ The manual curation (including merges and labels) can be applied to a SpikeInter
     from spikeinterface.widgets import plot_sorting_summary

     # run a sorter and export waveforms
-    sorting = run_sorter('kilosort2', recording)
-    we = extract_waveforms(recording, sorting, folder='wf_folder')
+    sorting = run_sorter(sorter_name='kilosort2', recording=recording)
+    we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder')

     # some postprocessing is required
-    _ = compute_spike_amplitudes(we)
-    _ = compute_unit_locations(we)
-    _ = compute_template_similarity(we)
-    _ = compute_correlograms(we)
+    _ = compute_spike_amplitudes(waveform_extractor=we)
+    _ = compute_unit_locations(waveform_extractor=we)
+    _ = compute_template_similarity(waveform_extractor=we)
+    _ = compute_correlograms(waveform_extractor=we)

     # This loads the data to the cloud for web-based plotting and sharing
-    plot_sorting_summary(we, curation=True, backend='sortingview')
+    plot_sorting_summary(waveform_extractor=we, curation=True, backend='sortingview')
     # we open the printed link URL in a browser
     # - make manual merges and labeling
     # - from the curation box, click on "Save as snapshot (sha1://)"
     # copy the uri
     sha_uri = "sha1://59feb326204cf61356f1a2eb31f04d8e0177c4f1"
-    clean_sorting = apply_sortingview_curation(sorting, uri_or_json=sha_uri)
+    clean_sorting = apply_sortingview_curation(sorting=sorting, uri_or_json=sha_uri)

 Note that you can also "Export as JSON" and pass the json file as :code:`uri_or_json` parameter.
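As an aside (illustrative only, not part of the diff), the JSON variant mentioned above is a
drop-in replacement for the sha1:// URI; the file name below is hypothetical:

.. code-block:: python

    # "my_curation.json" was saved from the sortingview curation box via "Export as JSON"
    clean_sorting = apply_sortingview_curation(sorting=sorting, uri_or_json="my_curation.json")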
diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst
index fa637f898b..1d23f9ad6f 100644
--- a/doc/modules/exporters.rst
+++ b/doc/modules/exporters.rst
@@ -28,15 +28,14 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:`
     from spikeinterface.exporters import export_to_phy

     # the waveforms are sparse so it is faster to export to phy
-    folder = 'waveforms'
-    we = extract_waveforms(recording, sorting, folder, sparse=True)
+    we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True)

     # some computations are done before to control all options
-    compute_spike_amplitudes(we)
-    compute_principal_components(we, n_components=3, mode='by_channel_global')
+    compute_spike_amplitudes(waveform_extractor=we)
+    compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global')

     # the export process is fast because everything is pre-computed
-    export_to_phy(we, output_folder='path/to/phy_folder')
+    export_to_phy(waveform_extractor=we, output_folder='path/to/phy_folder')



@@ -72,12 +71,12 @@ with many units!


     # the waveforms are sparse for more interpretable figures
-    we = extract_waveforms(recording, sorting, folder='path/to/wf', sparse=True)
+    we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf', sparse=True)

     # some computations are done before to control all options
-    compute_spike_amplitudes(we)
-    compute_correlograms(we)
-    compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'presence_ratio'])
+    compute_spike_amplitudes(waveform_extractor=we)
+    compute_correlograms(waveform_extractor=we)
+    compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'presence_ratio'])

     # the export process
-    export_report(we, output_folder='path/to/spikeinterface-report-folder')
+    export_report(waveform_extractor=we, output_folder='path/to/spikeinterface-report-folder')
diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst
index 5aed24ca41..1eeca9a325 100644
--- a/doc/modules/extractors.rst
+++ b/doc/modules/extractors.rst
@@ -6,18 +6,19 @@ Overview
 The :py:mod:`~spikeinterface.extractors` module allows you to load
 :py:class:`~spikeinterface.core.BaseRecording`, :py:class:`~spikeinterface.core.BaseSorting`, and
 :py:class:`~spikeinterface.core.BaseEvent` objects from
-a large variety of acquisition systems and spike sorting outputs.
+a large variety of acquisition systems and spike sorting outputs.

 Most of the :code:`Recording` classes are implemented by wrapping the
 `NEO rawio implementation `_.

 Most of the :code:`Sorting` classes are instead directly implemented in SpikeInterface.

-
 Although SpikeInterface is object-oriented (class-based), each object can also be loaded with a
 convenient :code:`read_XXXXX()` function.

+.. code-block:: python
+
     import spikeinterface.extractors as se

 Read one Recording
@@ -27,32 +28,44 @@ Every format can be read with a simple function:

 .. code-block:: python

-    recording_oe = read_openephys("open-ephys-folder")
+    recording_oe = read_openephys(folder_path="open-ephys-folder")

-    recording_spikeglx = read_spikeglx("spikeglx-folder")
+    recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder")

-    recording_blackrock = read_blackrock("blackrock-folder")
+    recording_blackrock = read_blackrock(folder_path="blackrock-folder")

-    recording_mearec = read_mearec("mearec_file.h5")
+    recording_mearec = read_mearec(file_path="mearec_file.h5")

 Importantly, some formats directly handle the probe information:

 ..
code-block:: python - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") print(recording_spikeglx.get_probe()) - recording_mearec = read_mearec("mearec_file.h5") + recording_mearec = read_mearec(file_path="mearec_file.h5") print(recording_mearec.get_probe()) +Although most recordings are loaded with the :py:mod:`~spikeinterface.extractors` +a few file formats are loaded from the :py:mod:`~spikeinterface.core` module + +.. code-block:: python + + import spikeinterface as si + + recording_binary = si.read_binary(file_path='binary.bin') + + recording_zarr = si.read_zarr(file_path='zarr_file.zarr') + + Read one Sorting ---------------- .. code-block:: python - sorting_KS = read_kilosort("kilosort-folder") + sorting_KS = read_kilosort(folder_path="kilosort-folder") Read one Event @@ -60,7 +73,7 @@ Read one Event .. code-block:: python - events_OE = read_openephys_event("open-ephys-folder") + events_OE = read_openephys_event(folder_path="open-ephys-folder") For a comprehensive list of compatible technologies, see :ref:`compatible_formats`. @@ -77,7 +90,7 @@ The actual reading will be done on demand using the :py:meth:`~spikeinterface.co .. code-block:: python # opening a 40GB SpikeGLX dataset is fast - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") # this really does load the full 40GB into memory : not recommended!!!!! traces = recording_spikeglx.get_traces(start_frame=None, end_frame=None, return_scaled=False) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index afedc4f982..96ecc1fcec 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -77,12 +77,12 @@ We currently have 3 presets: .. code-block:: python # read and preprocess - rec = read_spikeglx('/my/Neuropixel/recording') - rec = bandpass_filter(rec) - rec = common_reference(rec) + rec = read_spikeglx(folder_path='/my/Neuropixel/recording') + rec = bandpass_filter(recording=rec) + rec = common_reference(recording=rec) # then correction is one line of code - rec_corrected = correct_motion(rec, preset="nonrigid_accurate") + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate") The process is quite long due the two first steps (activity profile + motion inference) But the return :code:`rec_corrected` is a lazy recording object that will interpolate traces on the @@ -94,17 +94,17 @@ If you want to user other presets, this is as easy as: .. code-block:: python # mimic kilosort motion - rec_corrected = correct_motion(rec, preset="kilosort_like") + rec_corrected = correct_motion(recording=rec, preset="kilosort_like") # super but less accurate and rigid - rec_corrected = correct_motion(rec, preset="rigid_fast") + rec_corrected = correct_motion(recording=rec, preset="rigid_fast") Optionally any parameter from the preset can be overwritten: .. code-block:: python - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", detect_kwargs=dict( detect_threshold=10.), estimate_motion_kwargs=dic( @@ -123,7 +123,7 @@ and checking. The folder will contain the motion vector itself of course but als .. 
code-block:: python motion_folder = '/somewhere/to/save/the/motion' - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", folder=motion_folder) # and then motion_info = load_motion_info(motion_folder) @@ -156,14 +156,16 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte job_kwargs = dict(chunk_duration="1s", n_jobs=20, progress_bar=True) # Step 1 : activity profile - peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) + peaks = detect_peaks(recording=rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) # (optional) sub-select some peaks to speed up the localization - peaks = select_peaks(peaks, ...) - peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",radius_um=75.0, + peaks = select_peaks(peaks=peaks, ...) + peak_locations = localize_peaks(recording=rec, peaks=peaks, method="monopolar_triangulation",radius_um=75.0, max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(rec, peaks, peak_locations, + motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, method="decentralized", direction="y", bin_duration_s=2.0, @@ -173,7 +175,9 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte # Step 3: motion interpolation # this step is lazy - rec_corrected = interpolate_motion(rec, motion, temporal_bins, spatial_bins, + rec_corrected = interpolate_motion(recording=rec, motion=motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) @@ -196,20 +200,20 @@ different preprocessing chains: one for motion correction and one for spike sort .. code-block:: python - raw_rec = read_spikeglx(...) + raw_rec = read_spikeglx(folder_path='/spikeglx_folder') # preprocessing 1 : bandpass (this is smoother) + cmr - rec1 = si.bandpass_filter(raw_rec, freq_min=300., freq_max=5000.) - rec1 = si.common_reference(rec1, reference='global', operator='median') + rec1 = si.bandpass_filter(recording=raw_rec, freq_min=300., freq_max=5000.) + rec1 = si.common_reference(recording=rec1, reference='global', operator='median') # here the corrected recording is done on the preprocessing 1 # rec_corrected1 will not be used for sorting! motion_folder = '/my/folder' - rec_corrected1 = correct_motion(rec1, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected1 = correct_motion(recording=rec1, preset="nonrigid_accurate", folder=motion_folder) # preprocessing 2 : highpass + cmr - rec2 = si.highpass_filter(raw_rec, freq_min=300.) - rec2 = si.common_reference(rec2, reference='global', operator='median') + rec2 = si.highpass_filter(recording=raw_rec, freq_min=300.) 
+    rec2 = si.common_reference(recording=rec2, reference='global', operator='median')

     # we use another preprocessing for the final interpolation
     motion_info = load_motion_info(motion_folder)
@@ -220,7 +224,7 @@ different preprocessing chains: one for motion correction and one for spike sort
             spatial_bins=motion_info['spatial_bins'],
             **motion_info['parameters']['interpolate_motion_kwargs'])

-    sorting = run_sorter("montainsort5", rec_corrected2)
+    sorting = run_sorter(sorter_name="mountainsort5", recording=rec_corrected2)


 References
diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst
index a560f4d5c9..112c6e367d 100644
--- a/doc/modules/postprocessing.rst
+++ b/doc/modules/postprocessing.rst
@@ -14,9 +14,9 @@ WaveformExtractor extensions

 There are several postprocessing tools available, and all of them are implemented as a
 :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension`. All computations on top
-of a WaveformExtractor will be saved along side the WaveformExtractor itself (sub folder, zarr path or sub dict).
+of a :code:`WaveformExtractor` will be saved alongside the :code:`WaveformExtractor` itself (sub folder, zarr path or sub dict).
 This workflow is convenient for retrieval of time-consuming computations (such as pca or spike amplitudes) when reloading a
-WaveformExtractor.
+:code:`WaveformExtractor`.

 :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the
 parent :code:`WaveformExtractor` object, so that operations done on the :code:`WaveformExtractor`, such as saving,
@@ -80,9 +80,9 @@ This extension computes the principal components of the waveforms. There are sev

 * "by_channel_local" (default): fits one PCA model for each by_channel
 * "by_channel_global": fits the same PCA model to all channels (also termed temporal PCA)
-* "concatenated": contatenates all channels and fits a PCA model on the concatenated data
+* "concatenated": concatenates all channels and fits a PCA model on the concatenated data

-If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing PCA.
+If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing the PCA.
 For dense waveforms, sparsity can also be passed as an argument.

 For more information, see :py:func:`~spikeinterface.postprocessing.compute_principal_components`
@@ -127,7 +127,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate),

 For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_locations`

-unit locations
+unit_locations
 ^^^^^^^^^^^^^^


diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 7c1f33f298..67f1e52011 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -22,8 +22,8 @@ In this code example, we build a preprocessing chain with two steps:

     import spikeinterface.preprocessing import bandpass_filter, common_reference

     # recording is a RecordingExtractor object
-    recording_f = bandpass_filter(recording, freq_min=300, freq_max=6000)
-    recording_cmr = common_reference(recording_f, operator="median")
+    recording_f = bandpass_filter(recording=recording, freq_min=300, freq_max=6000)
+    recording_cmr = common_reference(recording=recording_f, operator="median")

 These two preprocessors will not compute anything at instantiation, but the computation will be "on-demand"
 ("on-the-fly") when getting traces.
@@ -38,7 +38,7 @@ save the object:

 .. 
code-block:: python

     # here the spykingcircus2 sorter engine directly uses the lazy "recording_cmr" object
-    sorting = run_sorter(recording_cmr, 'spykingcircus2')
+    sorting = run_sorter(recording=recording_cmr, sorter_name='spykingcircus2')

 Most of the external sorters, however, will need a binary file as input, so we can optionally save the
 processed recording with the efficient SpikeInterface :code:`save()` function:
@@ -64,12 +64,13 @@ dtype (unless specified otherwise):

 .. code-block:: python

+    import spikeinterface.extractors as se
     # spikeGLX is int16
-    rec_int16 = read_spikeglx("my_folder")
+    rec_int16 = se.read_spikeglx(folder_path="my_folder")
     # by default the int16 is kept
-    rec_f = bandpass_filter(rec_int16, freq_min=300, freq_max=6000)
+    rec_f = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000)
     # we can force a float32 casting
-    rec_f2 = bandpass_filter(rec_int16, freq_min=300, freq_max=6000, dtype='float32')
+    rec_f2 = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000, dtype='float32')

 Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`.
@@ -83,6 +84,8 @@ The full list of preprocessing functions can be found here: :ref:`api_preprocess

 Here is a full list of possible preprocessing steps, grouped by type of processing:

+For all examples :code:`rec` is a :code:`RecordingExtractor`.
+
 filter() / bandpass_filter() / notch_filter() / highpass_filter()
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -98,7 +101,7 @@ Important aspects of filtering functions:

 .. code-block:: python

-    rec_f = bandpass_filter(rec, freq_min=300, freq_max=6000)
+    rec_f = bandpass_filter(recording=rec, freq_min=300, freq_max=6000)

 * :py:func:`~spikeinterface.preprocessing.filter()`
@@ -119,7 +122,7 @@ There are various options when combining :code:`operator` and :code:`reference`

 .. code-block:: python

-    rec_cmr = common_reference(rec, operator="median", reference="global")
+    rec_cmr = common_reference(recording=rec, operator="median", reference="global")

 * :py:func:`~spikeinterface.preprocessing.common_reference()`
@@ -144,8 +147,8 @@ difference on artifact removal.

 .. code-block:: python

-    rec_shift = phase_shift(rec)
-    rec_cmr = common_reference(rec_shift, operator="median", reference="global")
+    rec_shift = phase_shift(recording=rec)
+    rec_cmr = common_reference(recording=rec_shift, operator="median", reference="global")


@@ -168,7 +171,7 @@ centered with unitary variance on each channel.

 .. code-block:: python

-    rec_normed = zscore(rec)
+    rec_normed = zscore(recording=rec)

 * :py:func:`~spikeinterface.preprocessing.normalize_by_quantile()`
 * :py:func:`~spikeinterface.preprocessing.scale()`
@@ -186,7 +189,7 @@ The whitened traces are then the dot product between the traces and the :code:`W

 .. code-block:: python

-    rec_w = whiten(rec)
+    rec_w = whiten(recording=rec)

 * :py:func:`~spikeinterface.preprocessing.whiten()`
@@ -199,7 +202,7 @@ The :code:`blank_staturation()` function is similar, but it automatically estima

 .. code-block:: python

-    rec_w = clip(rec, a_min=-250., a_max=260)
+    rec_w = clip(recording=rec, a_min=-250., a_max=260)

 * :py:func:`~spikeinterface.preprocessing.clip()`
 * :py:func:`~spikeinterface.preprocessing.blank_staturation()`
@@ -234,11 +237,11 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe

 .. 
code-block:: python

     # detect
-    bad_channel_ids, channel_labels = detect_bad_channels(rec)
+    bad_channel_ids, channel_labels = detect_bad_channels(recording=rec)
     # Case 1 : remove then
-    rec_clean = recording.remove_channels(bad_channel_ids)
+    rec_clean = recording.remove_channels(remove_channel_ids=bad_channel_ids)
     # Case 2 : interpolate then
-    rec_clean = interpolate_bad_channels(rec, bad_channel_ids)
+    rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids)

 * :py:func:`~spikeinterface.preprocessing.detect_bad_channels()`
@@ -257,13 +260,13 @@ remove_artifacts()

 Given an external list of trigger times, :code:`remove_artifacts()` function can remove artifacts with several strategies:

-* replace with zeros (blank)
-* make a linear or cubic interpolation
-* remove the median or average template (with optional time jitter and amplitude scaling correction)
+* replace with zeros (blank) :code:`'zeros'`
+* make a linear (:code:`'linear'`) or cubic (:code:`'cubic'`) interpolation
+* remove the median (:code:`'median'`) or average (:code:`'average'`) template (with optional time jitter and amplitude scaling correction)

 .. code-block:: python

-    rec_clean = remove_artifacts(rec, list_triggers)
+    rec_clean = remove_artifacts(recording=rec, list_triggers=[100, 200, 300], mode='zeros')

 * :py:func:`~spikeinterface.preprocessing.remove_artifacts()`
@@ -276,7 +279,7 @@ Similarly to :code:`numpy.astype()`, the :code:`astype()` casts the traces to th

 .. code-block:: python

-    rec_int16 = astype(rec_float, "int16")
+    rec_int16 = astype(recording=rec_float, dtype="int16")


 For recordings whose traces are unsigned (e.g. Maxwell Biosystems), the :code:`unsigned_to_signed()` function makes them
@@ -286,7 +289,7 @@ is subtracted, and the traces are finally cast to :code:`int16`:

 .. code-block:: python

-    rec_int16 = unsigned_to_signed(rec_uint16)
+    rec_int16 = unsigned_to_signed(recording=rec_uint16)

 * :py:func:`~spikeinterface.preprocessing.astype()`
 * :py:func:`~spikeinterface.preprocessing.unsigned_to_signed()`
@@ -300,7 +303,7 @@ required.

 .. code-block:: python

-    rec_with_more_channels = zero_channel_pad(rec, 128)
+    rec_with_more_channels = zero_channel_pad(parent_recording=rec, num_channels=128)

 * :py:func:`~spikeinterface.preprocessing.zero_channel_pad()`
@@ -331,7 +334,7 @@ How to implement "IBL destriping" or "SpikeGLX CatGT" in SpikeInterface

 SpikeGLX has a built-in function called `CatGT `_ to apply some preprocessing on the traces to
 remove noise and artifacts. IBL also has a standardized pipeline for preprocessed traces a bit similar to CatGT which is called
 "destriping" [IBL_spikesorting]_.
-In these both cases, the traces are entiely read, processed and written back to a file.
+In both these cases, the traces are entirely read, processed and written back to a file.

 SpikeInterface can reproduce similar results without the need to write back to a file by building a *lazy*
 preprocessing chain. Optionally, the result can still be written to a binary (or a zarr) file.
@@ -341,12 +344,12 @@ Here is a recipe to mimic the **IBL destriping**:

 .. 
code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = highpass_filter(rec, n_channel_pad=60) - rec = phase_shift(rec) - bad_channel_ids = detect_bad_channels(rec) - rec = interpolate_bad_channels(rec, bad_channel_ids) - rec = highpass_spatial_filter(rec) + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = highpass_filter(recording=rec, n_channel_pad=60) + rec = phase_shift(recording=rec) + bad_channel_ids = detect_bad_channels(recording=rec) + rec = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) + rec = highpass_spatial_filter(recording=rec) # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -356,9 +359,9 @@ Here is a recipe to mimic the **SpikeGLX CatGT**: .. code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = phase_shift(rec) - rec = common_reference(rec, operator="median", reference="global") + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = phase_shift(recording=rec) + rec = common_reference(recording=rec, operator="median", reference="global") # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -369,7 +372,6 @@ Of course, these pipelines can be enhanced and customized using other available - Preprocessing on Snippets ------------------------- diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 447d83db52..ec1788350f 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -47,16 +47,16 @@ This code snippet shows how to compute quality metrics (with or without principa .. code-block:: python - we = si.load_waveforms(...) # start from a waveform extractor + we = si.load_waveforms(folder='waveforms') # start from a waveform extractor # without PC - metrics = compute_quality_metrics(we, metric_names=['snr']) + metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr']) assert 'snr' in metrics.columns # with PCs from spikeinterface.postprocessing import compute_principal_components - pca = compute_principal_components(we, n_components=5, mode='by_channel_local') - metrics = compute_quality_metrics(we) + pca = compute_principal_components(waveform_extractor=we, n_components=5, mode='by_channel_local') + metrics = compute_quality_metrics(waveform_extractor=we) assert 'isolation_distance' in metrics.columns For more information about quality metrics, check out this excellent diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst index 13117b607c..81d3b4f12d 100644 --- a/doc/modules/qualitymetrics/amplitude_cv.rst +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -37,7 +37,7 @@ Example code # Make recording, sorting and wvf_extractor object for your data. # It is required to run `compute_spike_amplitudes(wvf_extractor)` or # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) - amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(wvf_extractor) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(waveform_extractor=wvf_extractor) # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, # and their amplitude_cv metrics as values. 
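
The amplitude CV example above assumes that spike amplitudes were already computed on the waveform extractor. A minimal end-to-end sketch of how that point is usually reached (assuming an existing :code:`recording` and :code:`sorting`, and the :code:`extract_waveforms`, :code:`compute_spike_amplitudes`, and :code:`compute_amplitude_cv_metrics` entry points referenced in these docs) could look like:

.. code-block:: python

    import spikeinterface as si
    from spikeinterface.postprocessing import compute_spike_amplitudes
    import spikeinterface.qualitymetrics as sqm

    # extract waveforms for the sorted units
    wvf_extractor = si.extract_waveforms(recording=recording, sorting=sorting, folder="waveforms")

    # spike amplitudes must be available before the amplitude CV metrics can be computed
    compute_spike_amplitudes(waveform_extractor=wvf_extractor)

    amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(
        waveform_extractor=wvf_extractor
    )
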
diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index 3ac52560e8..c77a57b033 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -24,7 +24,7 @@ Example code # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = sqm.compute_amplitude_medians(wvf_extractor) + amplitude_medians = sqm.compute_amplitude_medians(waveform_extractor=wvf_extractor) # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index e3bd61c580..9b540be743 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -34,7 +34,7 @@ Example code import spikeinterface.qualitymetrics as sqm - d_prime = sqm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index ae52f7f883..dad2aafe7c 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -43,10 +43,10 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - # It is required to run `compute_spike_locations(wvf_extractor)` + # It is required to run `compute_spike_locations(wvf_extractor) first` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(wvf_extractor, peak_sign="neg") - # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(waveform_extractor=wvf_extractor, peak_sign="neg") + # drift_ptps, drift_stds, and drift_mads are each a dict containing the unit IDs as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst index 925539e9c6..1cbd903c7a 100644 --- a/doc/modules/qualitymetrics/firing_range.rst +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -24,7 +24,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_range = sqm.compute_firing_ranges(wvf_extractor) + firing_range = sqm.compute_firing_ranges(waveform_extractor=wvf_extractor) # firing_range is a dict containing the unit IDs as keys, # and their firing firing_range as values (in Hz). diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index c0e15d7c2e..ef8cb3d8f4 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -40,7 +40,7 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_rate = sqm.compute_firing_rates(wvf_extractor) + firing_rate = sqm.compute_firing_rates(waveform_extractor=wvf_extractor) # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). 
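
The firing rate reported above is the spike count divided by the total recording duration. A rough manual cross-check, written as a sketch that assumes a single-segment :code:`recording` / :code:`sorting` pair and a valid :code:`unit_id`, could be:

.. code-block:: python

    # total number of spikes for one unit in segment 0
    spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)

    # duration of segment 0 in seconds
    duration_s = recording.get_num_frames(segment_index=0) / recording.get_sampling_frequency()

    # should roughly match the value returned by compute_firing_rates()
    firing_rate_hz = spike_train.size / duration_s
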
diff --git a/doc/modules/qualitymetrics/isolation_distance.rst b/doc/modules/qualitymetrics/isolation_distance.rst index 640a5a8b5a..6ba0d0b1ec 100644 --- a/doc/modules/qualitymetrics/isolation_distance.rst +++ b/doc/modules/qualitymetrics/isolation_distance.rst @@ -23,6 +23,16 @@ Expectation and use Isolation distance can be interpreted as a measure of distance from the cluster to the nearest other cluster. A well isolated unit should have a large isolation distance. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + iso_distance, _ = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/l_ratio.rst b/doc/modules/qualitymetrics/l_ratio.rst index b37913ba58..ae31ab40a4 100644 --- a/doc/modules/qualitymetrics/l_ratio.rst +++ b/doc/modules/qualitymetrics/l_ratio.rst @@ -37,6 +37,17 @@ Since this metric identifies unit separation, a high value indicates a highly co A well separated unit should have a low L-ratio ([Schmitzer-Torbert]_ et al.). + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + _, l_ratio = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index 5a420c8ccf..ad0766d37c 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -27,7 +27,7 @@ Example code # Make recording, sorting and wvf_extractor object for your data. - presence_ratio = sqm.compute_presence_ratios(wvf_extractor) + presence_ratio = sqm.compute_presence_ratios(waveform_extractor=wvf_extractor) # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index b924cdbf73..7da01e0476 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -50,6 +50,16 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + simple_sil_score = sqm.simplified_silhouette_score(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index de68c3a92f..fd53d7da3b 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -31,7 +31,7 @@ With SpikeInterface: # Make recording, sorting and wvf_extractor object for your data. 
- contamination = sqm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(waveform_extractor=wvf_extractor, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index b88d3291be..7f27a5078a 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -44,8 +44,7 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - - SNRs = sqm.compute_snrs(wvf_extractor) + SNRs = sqm.compute_snrs(waveform_extractor=wvf_extractor) # SNRs is a dict containing the unit IDs as keys and their SNRs as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 0750940199..d1a3c70a97 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -29,7 +29,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - synchrony = sqm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(waveform_extractor=wvf_extractor, synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index f3c8e7b733..5040b01ec2 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -49,15 +49,15 @@ to easily run spike sorters: from spikeinterface.sorters import run_sorter # run Tridesclous - sorting_TDC = run_sorter("tridesclous", recording, output_folder="/folder_TDC") + sorting_TDC = run_sorter(sorter_name="tridesclous", recording=recording, output_folder="/folder_TDC") # run Kilosort2.5 - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5") + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5") # run IronClust - sorting_IC = run_sorter("ironclust", recording, output_folder="/folder_IC") + sorting_IC = run_sorter(sorter_name="ironclust", recording=recording, output_folder="/folder_IC") # run pyKilosort - sorting_pyKS = run_sorter("pykilosort", recording, output_folder="/folder_pyKS") + sorting_pyKS = run_sorter(sorter_name="pykilosort", recording=recording, output_folder="/folder_pyKS") # run SpykingCircus - sorting_SC = run_sorter("spykingcircus", recording, output_folder="/folder_SC") + sorting_SC = run_sorter(sorter_name="spykingcircus", recording=recording, output_folder="/folder_SC") Then the output, which is a :py:class:`~spikeinterface.core.BaseSorting` object, can be easily @@ -81,10 +81,10 @@ Spike-sorter-specific parameters can be controlled directly from the .. code-block:: python - sorting_TDC = run_sorter('tridesclous', recording, output_folder="/folder_TDC", + sorting_TDC = run_sorter(sorter_name='tridesclous', recording=recording, output_folder="/folder_TDC", detect_threshold=8.) - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5" + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5" do_correction=False, preclust_threshold=6, freq_min=200.) 
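
Before overriding sorter-specific parameters as in the example above, the accepted parameter names and their defaults can be inspected first; a small sketch, assuming the :code:`get_default_sorter_params` helper exposed by :code:`spikeinterface.sorters`, might be:

.. code-block:: python

    from spikeinterface.sorters import get_default_sorter_params

    # list the keyword arguments accepted by a given sorter, with their default values
    default_params = get_default_sorter_params("kilosort2_5")
    print(default_params)
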
@@ -185,7 +185,7 @@ The following code creates a test recording and runs a containerized spike sorte ) test_recording = test_recording.save(folder="test-docker-folder") - sorting = ss.run_sorter('kilosort3', + sorting = ss.run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="kilosort3", singularity_image=True) @@ -201,7 +201,7 @@ To run in Docker instead of Singularity, use ``docker_image=True``. .. code-block:: python - sorting = run_sorter('kilosort3', recording=test_recording, + sorting = run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="/tmp/kilosort3", docker_image=True) To use a specific image, set either ``docker_image`` or ``singularity_image`` to a string, @@ -209,7 +209,7 @@ e.g. ``singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0"``. .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=test_recording, output_folder="kilosort3", singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0") @@ -271,7 +271,7 @@ And use the custom image whith the :code:`run_sorter` function: .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=recording, docker_image="my-user/ks3-with-spikeinterface-test:0.1.0") @@ -302,7 +302,7 @@ an :code:`engine` that supports parallel processing (such as :code:`joblib` or : ] # run in loop - sortings = run_sorter_jobs(job_list, engine='loop') + sortings = run_sorter_jobs(job_list=job_list, engine='loop') @@ -314,11 +314,11 @@ an :code:`engine` that supports parallel processing (such as :code:`joblib` or : .. code-block:: python - run_sorter_jobs(job_list, engine='loop') + run_sorter_jobs(job_list=job_list, engine='loop') - run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list=job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list=job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem': '5G'}) Spike sorting by group @@ -374,7 +374,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: # here the result is a dict of a sorting object sortings = {} for group, sub_recording in recordings.items(): - sorting = run_sorter('kilosort2', recording, output_folder=f"folder_KS2_group{group}") + sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}") sortings[group] = sorting **Option 2 : Automatic splitting** @@ -382,7 +382,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: .. 
code-block:: python # here the result is one sorting that aggregates all sub sorting objects - aggregate_sorting = run_sorter_by_property('kilosort2', recording_4_tetrodes, + aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes, grouping_property='group', working_folder='working_path') @@ -421,7 +421,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 4 segments of 10s each # run tridesclous in multi-segment mode - multisorting = si.run_sorter('tridesclous', multirecording) + multisorting = si.run_sorter(sorter_name='tridesclous', recording=multirecording) print(multisorting) # Case 2: the sorter DOES NOT handle multi-segment objects @@ -433,7 +433,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 1 segment of 40s each # run mountainsort4 in mono-segment mode - multisorting = si.run_sorter('mountainsort4', multirecording) + multisorting = si.run_sorter(sorter_name='mountainsort4', recording=multirecording) See also the :ref:`multi_seg` section. @@ -507,7 +507,7 @@ message will appear indicating how to install the given sorter, .. code:: python - recording = run_sorter('ironclust', recording) + recording = run_sorter(sorter_name='ironclust', recording=recording) throws the error, @@ -540,7 +540,7 @@ From the user's perspective, they behave exactly like the external sorters: .. code-block:: python - sorting = run_sorter("spykingcircus2", recording, "/tmp/folder") + sorting = run_sorter(sorter_name="spykingcircus2", recording=recording, output_folder="/tmp/folder") Contributing diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index 422eaea890..f3371f7e7b 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -47,7 +47,8 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) peaks = detect_peaks( - recording, method='by_channel', + recording=recording, + method='by_channel', peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, @@ -94,7 +95,7 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) - peak_locations = localize_peaks(recording, peaks, method='center_of_mass', + peak_locations = localize_peaks(recording=recording, peaks=peaks, method='center_of_mass', radius_um=70., ms_before=0.3, ms_after=0.6, **job_kwargs) @@ -122,7 +123,7 @@ For instance, the 'monopolar_triangulation' method will have: .. note:: - By convention in SpikeInterface, when a probe is described in 2d + By convention in SpikeInterface, when a probe is described in 3d * **'x'** is the width of the probe * **'y'** is the depth * **'z'** is orthogonal to the probe plane @@ -144,11 +145,11 @@ can be *hidden* by this process. from spikeinterface.sortingcomponents.peak_detection import detect_peaks - many_peaks = detect_peaks(...) + many_peaks = detect_peaks(...) # as in above example from spikeinterface.sortingcomponents.peak_selection import select_peaks - some_peaks = select_peaks(many_peaks, method='uniform', n_peaks=10000) + some_peaks = select_peaks(peaks=many_peaks, method='uniform', n_peaks=10000) Implemented methods are the following: @@ -183,15 +184,15 @@ Here is an example with non-rigid motion estimation: .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording=ecording, ...) 
# as in above example from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, ...) + peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) # as above from spikeinterface.sortingcomponents.motion_estimation import estimate_motion motion, temporal_bins, spatial_bins, - extra_check = estimate_motion(recording, peaks, peak_locations=peak_locations, + extra_check = estimate_motion(recording=recording, peaks=peaks, peak_locations=peak_locations, direction='y', bin_duration_s=10., bin_um=10., margin_um=0., method='decentralized_registration', rigid=False, win_shape='gaussian', win_step_um=50., win_sigma_um=150., @@ -217,7 +218,7 @@ Here is a short example that depends on the output of "Motion interpolation": from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording - recording_corrected = InterpolateMotionRecording(recording_with_drift, motion, temporal_bins, spatial_bins + recording_corrected = InterpolateMotionRecording(recording=recording_with_drift, motion=motion, temporal_bins=temporal_bins, spatial_bins=spatial_bins spatial_interpolation_method='kriging, border_mode='remove_channels') @@ -255,10 +256,10 @@ Different methods may need different inputs (for instance some of them require p .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording, ...) # as in above example from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks - labels, peak_labels = find_cluster_from_peaks(recording, peaks, method="sliding_hdbscan") + labels, peak_labels = find_cluster_from_peaks(recording=recording, peaks=peaks, method="sliding_hdbscan") * **labels** : contains all possible labels diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 8565e94fce..f37b2a5a6f 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -148,7 +148,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. code-block:: python # matplotlib backend - w = plot_traces(recording, backend="matplotlib") + w = plot_traces(recording=recording, backend="matplotlib") **Output:** @@ -173,7 +173,7 @@ Each function has the following additional arguments: # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_traces(rec_dict, backend="ipywidgets") + w = sw.plot_traces(recording=rec_dict, backend="ipywidgets") **Output:** @@ -196,8 +196,8 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_traces(recording, backend="ipywidgets") - w_ss = sw.plot_sorting_summary(recording, backend="sortingview") + w_ts = sw.plot_traces(recording=recording, backend="ipywidgets") + w_ss = sw.plot_sorting_summary(recording=recording, backend="sortingview") **Output:** @@ -249,7 +249,7 @@ The :code:`ephyviewer` backend is currently only available for the :py:func:`~sp .. code-block:: python - plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True) + plot_traces(recording=recording, backend="ephyviewer", mode="line", show_channel_ids=True) .. 
image:: ../images/plot_traces_ephyviewer.png From 5140a0423f8c33e3ba6906d48169508585e19807 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Sep 2023 20:32:49 +0000 Subject: [PATCH 20/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/modules/extractors.rst | 2 +- doc/modules/motion_correction.rst | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index 1eeca9a325..ccc5d2a311 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -6,7 +6,7 @@ Overview The :py:mod:`~spikeinterface.extractors` module allows you to load :py:class:`~spikeinterface.core.BaseRecording`, :py:class:`~spikeinterface.core.BaseSorting`, and :py:class:`~spikeinterface.core.BaseEvent` objects from -a large variety of acquisition systems and spike sorting outputs. +a large variety of acquisition systems and spike sorting outputs. Most of the :code:`Recording` classes are implemented by wrapping the `NEO rawio implementation `_. diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 96ecc1fcec..e009e06236 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -163,8 +163,8 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, - peaks=peaks, + motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, + peaks=peaks, peak_locations=peak_locations, method="decentralized", direction="y", @@ -175,8 +175,8 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte # Step 3: motion interpolation # this step is lazy - rec_corrected = interpolate_motion(recording=rec, motion=motion, - temporal_bins=temporal_bins, + rec_corrected = interpolate_motion(recording=rec, motion=motion, + temporal_bins=temporal_bins, spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", From 714645c4fcf359612d2ba31ca4f79fbfd42165c4 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:40:37 -0400 Subject: [PATCH 21/74] fix -> dict --- doc/modules/motion_correction.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 96ecc1fcec..8cffeebcf3 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -107,7 +107,7 @@ Optionally any parameter from the preset can be overwritten: rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", detect_kwargs=dict( detect_threshold=10.), - estimate_motion_kwargs=dic( + estimate_motion_kwargs=dict( histogram_depth_smooth_um=8., time_horizon_s=120., ), From f76e9d895a321eceb8dd6e01f0e3fe769867ec16 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:14:50 +0200 Subject: [PATCH 22/74] Update src/spikeinterface/curation/sortingview_curation.py --- src/spikeinterface/curation/sortingview_curation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 7a573c38c4..626ea79eb9 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ 
b/src/spikeinterface/curation/sortingview_curation.py @@ -94,7 +94,6 @@ def apply_sortingview_curation( # Populate the properties dictionary for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): unit_id_str = str(unit_id) - # Check for exact match first if unit_id_str in labels_dict: for label in labels_dict[unit_id_str]: properties[label][unit_index] = True From c20ffdadb908d601e546323b113e994445546891 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:23:47 +0200 Subject: [PATCH 23/74] Tiny rewrite in tests --- src/spikeinterface/curation/tests/test_sortingview_curation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 22085f2f77..ce6c7dd5a6 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -174,8 +174,9 @@ def test_label_inheritance_int(): duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_spikes = 1000 + num_units = 7 times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) - labels = np.random.randint(1, 8, size=num_spikes) # 7 units: 1 to 7 + labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) From bbc81676fbcec04cb7ce9d6f93da60ed1afb0df5 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 2 Oct 2023 11:00:38 +0200 Subject: [PATCH 24/74] Minor fixes for SC2 and study --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- src/spikeinterface/sortingcomponents/matching/circus.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index d43727cb44..df0b5296c0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -180,7 +180,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True if sorting_exists: # delete older sorting + log before running sorters - shutil.rmtree(sorting_exists) + shutil.rmtree(sorting_folder) log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" if log_file.exists(): log_file.unlink() diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 1a8332ad6d..891c355448 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -661,7 +661,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording + del recording, sub_recording, method_kwargs os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 358691cd25..ea36b75847 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -592,6 +592,7 @@ def _prepare_templates(cls, d): d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) d["singular"] = 
d["singular"].T[:, :, np.newaxis] + return d @classmethod From 6ceee13abe776ceec65dd6239f5f97fbca1096a4 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Mon, 2 Oct 2023 05:08:25 -0400 Subject: [PATCH 25/74] Alessio fixes Co-authored-by: Alessio Buccino --- doc/modules/exporters.rst | 2 +- doc/modules/extractors.rst | 10 ---------- doc/modules/sortingcomponents.rst | 2 +- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 1d23f9ad6f..155050ddb0 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -31,7 +31,7 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(waveform_extractor = we) + compute_spike_amplitudes(waveform_extractor=we) compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global') # the export process is fast because everything is pre-computed diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index ccc5d2a311..2d0e047672 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -48,16 +48,6 @@ Importantly, some formats directly handle the probe information: print(recording_mearec.get_probe()) -Although most recordings are loaded with the :py:mod:`~spikeinterface.extractors` -a few file formats are loaded from the :py:mod:`~spikeinterface.core` module - -.. code-block:: python - - import spikeinterface as si - - recording_binary = si.read_binary(file_path='binary.bin') - - recording_zarr = si.read_zarr(file_path='zarr_file.zarr') Read one Sorting diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index f3371f7e7b..1e58972497 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -184,7 +184,7 @@ Here is an example with non-rigid motion estimation: .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording=ecording, ...) # as in above example + peaks = detect_peaks(recording=recording, ...) # as in above example from spikeinterface.sortingcomponents.peak_localization import localize_peaks peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) 
# as above From 5cefdacc3674162155b5eaa3a612b5cc2ca79675 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 13:48:22 +0200 Subject: [PATCH 26/74] Fixes to MDASortingExtractor --- src/spikeinterface/extractors/mdaextractors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index b863e338fa..1eb0182318 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -216,14 +216,14 @@ def write_sorting(sorting, save_path, write_primary_channels=False): times_list = [] labels_list = [] primary_channels_list = [] - for unit_id_i, unit_id in enumerate(unit_ids): + for unit_index, unit_id in enumerate(unit_ids): times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) # unit id may not be numeric - if unit_id.dtype.kind in "biufc": - labels_list.append(np.ones(times.shape) * unit_id) + if unit_id.dtype.kind in "iu": + labels_list.append(np.ones(times.shape, dtype=unit_id.dtype) * unit_id) else: - labels_list.append(np.ones(times.shape) * unit_id_i) + labels_list.append(np.ones(times.shape, dtype=int) * unit_index) if write_primary_channels: if "max_channel" in sorting.get_unit_property_names(unit_id): primary_channels_list.append([sorting.get_unit_property(unit_id, "max_channel")] * times.shape[0]) From c06df711a3dfb0f08d6eb8718147210be0c144c6 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Mon, 2 Oct 2023 08:36:42 -0400 Subject: [PATCH 27/74] add pypi docs and dev docs --- README.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 55f33d04b1..883dcdb944 100644 --- a/README.md +++ b/README.md @@ -59,15 +59,17 @@ With SpikeInterface, users can: - post-process sorted datasets. - compare and benchmark spike sorting outputs. - compute quality metrics to validate and curate spike sorting outputs. -- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, in jupyter) -- export report and export to phy -- offer a powerful Qt-based viewer in separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) -- have some powerful sorting components to build your own sorter. +- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) +- export a report and/or export to phy +- offer a powerful Qt-based viewer in a separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) +- have powerful sorting components to build your own sorter. ## Documentation -Detailed documentation for spikeinterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.98.2). + +Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). Several tutorials to get started can be found in [spiketutorials](https://github.com/SpikeInterface/spiketutorials). @@ -77,9 +79,9 @@ and sorting components. You can also have a look at the [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui). 
-## How to install spikeinteface +## How to install spikeinterface -You can install the new `spikeinterface` version with pip: +You can install the latest version of `spikeinterface` version with pip: ```bash pip install spikeinterface[full] @@ -94,7 +96,7 @@ To install all interactive widget backends, you can use: ``` -To get the latest updates, you can install `spikeinterface` from sources: +To get the latest updates, you can install `spikeinterface` from source: ```bash git clone https://github.com/SpikeInterface/spikeinterface.git From cf65301c82c48e72e10a77f6a7f891453b69e409 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 15:24:30 +0200 Subject: [PATCH 28/74] Check main_ids are ints or strings --- src/spikeinterface/core/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8b4f094c20..86692fa69c 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -47,6 +47,7 @@ def __init__(self, main_ids: Sequence) -> None: # 'main_ids' will either be channel_ids or units_ids # They is used for properties self._main_ids = np.array(main_ids) + assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" # dict at object level self._annotations = {} From 8343d3a70a6bb3cf56f3013abc77c8e534059150 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 15:57:54 +0200 Subject: [PATCH 29/74] Fix NpySnippets --- src/spikeinterface/core/base.py | 3 ++- src/spikeinterface/core/baserecordingsnippets.py | 4 ++-- src/spikeinterface/core/basesnippets.py | 2 -- src/spikeinterface/core/npysnippetsextractor.py | 5 ++++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 86692fa69c..f1a51c99d1 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -47,7 +47,8 @@ def __init__(self, main_ids: Sequence) -> None: # 'main_ids' will either be channel_ids or units_ids # They is used for properties self._main_ids = np.array(main_ids) - assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" + if len(self._main_ids) > 0: + assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" # dict at object level self._annotations = {} diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index f35bc2b266..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import 
BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) From 89d1f827c445702a61eda864c9972401567a9b67 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 16:26:25 +0200 Subject: [PATCH 30/74] Force CellExplorer unit ids as int --- src/spikeinterface/core/base.py | 4 +++- src/spikeinterface/extractors/cellexplorersortingextractor.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index f1a51c99d1..1116aeb507 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -48,7 +48,9 @@ def __init__(self, main_ids: Sequence) -> None: # They is used for properties self._main_ids = np.array(main_ids) if len(self._main_ids) > 0: - assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 31241a4147..f72670fbcd 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = unit_ids[:].astype(int).tolist() spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) From d75f0588707da10a61e926e337334739a0b9a20b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Oct 2023 10:58:15 +0200 Subject: [PATCH 31/74] Update src/spikeinterface/extractors/cellexplorersortingextractor.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/extractors/cellexplorersortingextractor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index f72670fbcd..0096a40a79 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -119,6 +119,7 @@ def __init__( # CellExplorer reports spike times in units 
seconds; SpikeExtractors uses time units of sampling frames unit_ids = unit_ids[:].astype(int).tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) From 1939b936e94d30c8437633f89c49fd006ca71a80 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 4 Oct 2023 10:19:11 +0200 Subject: [PATCH 32/74] Diff for SC2 --- src/spikeinterface/sorters/internal/spyking_circus2.py | 7 ++++--- .../sortingcomponents/clustering/clustering_tools.py | 7 +++++-- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a0a4d0823c..db06287f6c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,7 +6,7 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, - "filtering": {"dtype": "float32"}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + #recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 891c355448..6dba4b7f0f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -598,14 +598,17 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in np.arange(nb_templates)[indices]: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py 
b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..d7ceef2561 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf", + "cluster_selection_method": "leaf" }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From a46994f5ea58e4359ef0a514bae9cd96dc2bf5f8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 13:43:48 +0200 Subject: [PATCH 33/74] waveform extactor reload --- src/spikeinterface/core/waveform_extractor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 2710ff1338..6d9e5d41e3 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -175,7 +175,13 @@ def load_from_folder( rec_attributes = None if sorting is None: - sorting = load_extractor(folder / "sorting.json", base_folder=folder) + if (folder / "sorting.json").exists(): + sorting = load_extractor(folder / "sorting.json", base_folder=folder) + elif (folder / "sorting.pickle").exists(): + sorting = load_extractor(folder / "sorting.pickle") + else: + raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") + # the sparsity is the sparsity of the saved/cached waveforms arrays sparsity_file = folder / "sparsity.json" From 7d9c0753fb3c59577dd244d3c9bce1d6272015e6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 14:11:54 +0200 Subject: [PATCH 34/74] WIP --- src/spikeinterface/preprocessing/remove_artifacts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 7e84822c61..8e72b96c6d 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,4 +1,5 @@ import numpy as np +import scipy from spikeinterface.core.core_tools import define_function_from_class From 87a9dc964d59530267ed5be8b297a08b35427b75 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 15:36:07 +0200 Subject: [PATCH 35/74] yep --- src/spikeinterface/core/generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 06a5ec96ec..9d656db977 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1405,6 +1405,7 @@ def generate_ground_truth_recording( assert sorting.sampling_frequency == sampling_frequency num_spikes = sorting.to_spike_vector().size + if probe is None: probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) From 0c97fc46adfb8c19683285ab77338ea9e103ac25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 13:37:30 +0000 Subject: [PATCH 36/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 1 - src/spikeinterface/core/waveform_extractor.py | 1 - 2 files changed, 2 deletions(-) diff --git 
a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9d656db977..06a5ec96ec 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1405,7 +1405,6 @@ def generate_ground_truth_recording( assert sorting.sampling_frequency == sampling_frequency num_spikes = sorting.to_spike_vector().size - if probe is None: probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6d9e5d41e3..576a0a1a58 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -182,7 +182,6 @@ def load_from_folder( else: raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") - # the sparsity is the sparsity of the saved/cached waveforms arrays sparsity_file = folder / "sparsity.json" if sparsity_file.is_file(): From 86b2271df55b671b49cd5b58601df94ab0dd2109 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 16:03:51 +0200 Subject: [PATCH 37/74] Change some default parameters for better user experience. --- src/spikeinterface/core/waveform_extractor.py | 8 ++++---- src/spikeinterface/postprocessing/correlograms.py | 4 ++-- src/spikeinterface/postprocessing/unit_localization.py | 2 +- src/spikeinterface/sorters/runsorter.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6d9e5d41e3..1c6002226f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1458,13 +1458,13 @@ def extract_waveforms( folder=None, mode="folder", precompute_template=("average",), - ms_before=3.0, - ms_after=4.0, + ms_before=1.0, + ms_after=2.0, max_spikes_per_unit=500, overwrite=False, return_scaled=True, dtype=None, - sparse=False, + sparse=True, sparsity=None, num_spikes_for_sparsity=100, allow_unfiltered=False, @@ -1508,7 +1508,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default False) + sparse: bool (default True) If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6cd5238abd..6e693635eb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ def compute_correlograms( waveform_or_sorting_extractor, load_if_exists=False, - window_ms: float = 100.0, - bin_ms: float = 5.0, + window_ms: float = 50.0, + bin_ms: float = 1.0, method: str = "auto", ): """Compute auto and cross correlograms. 
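A note on the behavioural impact of the new extract_waveforms defaults above: waveforms are now sparse and shorter (1 ms before / 2 ms after the peak) unless the caller says otherwise. A minimal usage sketch, assuming recording and sorting are existing objects and the folder names are placeholders:

    from spikeinterface.core import extract_waveforms

    # new defaults: ms_before=1.0, ms_after=2.0, sparse=True
    we_sparse = extract_waveforms(recording, sorting, folder="waveforms_sparse")

    # opting out explicitly reproduces the previous dense behaviour
    we_dense = extract_waveforms(
        recording, sorting, folder="waveforms_dense",
        ms_before=3.0, ms_after=4.0, sparse=False,
    )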
diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d2739f69dd..48ceb34a4e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -96,7 +96,7 @@ def get_extension_function(): def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs + waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs ): """ Localize units in 2D or 3D with several methods given the template. diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 9bacd8e2c9..a49a605a75 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -91,7 +91,7 @@ def run_sorter( sorter_name: str, recording: BaseRecording, output_folder: Optional[str] = None, - remove_existing_folder: bool = True, + remove_existing_folder: bool = False, delete_output_folder: bool = False, verbose: bool = False, raise_error: bool = True, From e97005aa5e94328cee3d97097b98d6a7289ee437 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 16:21:54 +0200 Subject: [PATCH 38/74] Patch for scipy --- src/spikeinterface/preprocessing/remove_artifacts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 8e72b96c6d..1746b23941 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,5 +1,4 @@ import numpy as np -import scipy from spikeinterface.core.core_tools import define_function_from_class @@ -108,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -237,7 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -255,6 +251,8 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): + + if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -286,6 +284,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 From 2a5e37c83054999514ccacd45b3c81d1865bc196 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:23:26 +0000 Subject: [PATCH 39/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/remove_artifacts.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 1746b23941..1eafa48a0b 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -234,7 +234,6 @@ def __init__( time_pad, sparsity, ): - 
BasePreprocessorSegment.__init__(self, parent_recording_segment) self.triggers = np.asarray(triggers, dtype="int64") @@ -251,8 +250,6 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): - - if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -285,6 +282,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: import scipy.interpolate + for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 From d9803d43e9598810337d11d2e68414261dbc3b81 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 17:07:05 +0200 Subject: [PATCH 40/74] oups --- src/spikeinterface/core/waveform_extractor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index d83b3d66f1..eb027faf81 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1726,6 +1726,7 @@ def precompute_sparsity( max_spikes_per_unit=num_spikes_for_sparsity, return_scaled=False, allow_unfiltered=allow_unfiltered, + sparse=False, **job_kwargs, ) local_sparsity = compute_sparsity(local_we, **sparse_kwargs) From 590cd6ba2440569469859a0e08ce321a5320e27d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 21:04:26 +0200 Subject: [PATCH 41/74] small fix --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index d25f1ea97b..364fc298c6 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -43,6 +43,8 @@ def plot(self): self._do_plot() def _do_plot(self): + from matplotlib import pyplot as plt + fig = self.figure for ax in fig.axes: From 204c8e90fd44d56e4b5eb6b0b7e92f09ea18db91 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 21:08:17 +0200 Subject: [PATCH 42/74] fix waveform extactor with empty sorting and sparse --- src/spikeinterface/core/sparsity.py | 6 +++++- src/spikeinterface/core/tests/test_waveform_extractor.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8c5c62d568..896e3800d7 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids): self.num_channels = self.channel_ids.size self.num_units = self.unit_ids.size - self.max_num_active_channels = self.mask.sum(axis=1).max() + if self.mask.shape[0]: + self.max_num_active_channels = self.mask.sum(axis=1).max() + else: + # empty sorting without units + self.max_num_active_channels = 0 def __repr__(self): density = np.mean(self.mask) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..00244f600b 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -556,4 +556,5 @@ def test_non_json_object(): # test_portability() # test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() + 
test_empty_sorting() From 4cd3747786728e2942bef43b5c9d5ecba8d102fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:25:31 +0000 Subject: [PATCH 43/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db06287f6c..6cf925e852 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -65,7 +65,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: recording_f = recording - #recording_f = whiten(recording_f, dtype="float32") + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..72cfd71791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -602,7 +602,7 @@ def remove_duplicates_via_matching( } ) - spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], return_counts=True) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) ignore_ids = [] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf" + "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From 22c0eb426507be87790cbcd68427e3d3764721ee Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 08:29:18 +0200 Subject: [PATCH 44/74] Fix bug while reloading --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..d94345f56b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -664,7 +664,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs + del recording, sub_recording, method_kwargs, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git 
a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..4d1dd1f9d5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,6 +223,8 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting + import gc + gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From f69d7e3dbd013c52564b79c1f6ce5c87a3f67af0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:30:11 +0000 Subject: [PATCH 45/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 7cb882409d..620346a875 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -224,6 +224,7 @@ def sigmoid(x, L, x0, k, b): del we, sorting import gc + gc.collect() if params["tmp_folder"] is None: From 403890ce83b065a76bcc1542a562d1a73e6e04be Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:01:02 +0200 Subject: [PATCH 46/74] Found it! --- .../clustering/clustering_tools.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index ce29c47113..734ceff1a3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,9 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - method_kwargs.update( + local_params = method_kwargs.copy() + + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, @@ -613,12 +615,12 @@ def remove_duplicates_via_matching( t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) if method == "circus-omp-svd": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -632,7 +634,7 @@ def remove_duplicates_via_matching( } ) elif method == "circus-omp": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -664,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs, waveform_extractor + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels From 6951e856c0794e78108be180d6f16e0fde6af6e2 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:54:27 +0200 Subject: [PATCH 47/74] WIP 
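The fix in the "Found it!" patch above replaces in-place updates of the shared method_kwargs dict with a per-iteration copy, which keeps ignored_ids and the cached matching state from leaking between template iterations and back to the caller. A minimal sketch of the corrected pattern, with process() as a placeholder standing in for the real matching call:

    def process(template, **kwargs):
        # placeholder for the real per-template matching call
        return template, sorted(kwargs)

    def run_over_templates(templates, method_kwargs):
        results = []
        for i, template in enumerate(templates):
            # copy first, then add per-iteration keys: the shared dict is never mutated
            local_params = method_kwargs.copy()
            local_params["ignored_ids"] = [i]
            results.append(process(template, **local_params))
        return results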
--- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 734ceff1a3..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -610,7 +610,7 @@ def remove_duplicates_via_matching( ignore_ids = [] similar_templates = [[], []] - for i in np.arange(nb_templates)[indices]: + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 620346a875..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,9 +223,6 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting - import gc - - gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From fdebd12b09654796a177f4ab91b8e614409f5ac7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 10:43:20 +0200 Subject: [PATCH 48/74] Sparse waveforms were not handled --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- .../sortingcomponents/clustering/random_projections.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6cf925e852..0c3b9f95d1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -99,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms_kwargs"] = params["waveforms"] + clustering_params["waveforms"] = params["waveforms"].copy() for k in ["ms_before", "ms_after"]: - clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params["waveforms"][k] = params["general"][k] clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..ffb868f682 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], **params["job_kwargs"], + **params['waveforms'], return_scaled=False, mode=mode, ) From b6f9235a7cf9c2ad106ec0e4cb6be365a243d2af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 08:44:20 +0000 Subject: [PATCH 49/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ffb868f682..a81458d7a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -200,7 +200,7 @@ def sigmoid(x, L, x0, k, b): sorting, waveform_folder, **params["job_kwargs"], - **params['waveforms'], + **params["waveforms"], return_scaled=False, mode=mode, ) From f4f3fb4199a59add1882b26e0925e08c00d1fed3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 11:34:41 +0200 Subject: [PATCH 50/74] wip --- .../sortingcomponents/clustering/merge.py | 87 ++++++++++++++++--- .../sortingcomponents/clustering/split.py | 13 ++- 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index e2049d70bf..1dd9f9fc37 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -80,9 +80,45 @@ def merge_clusters( method_kwargs=method_kwargs, **job_kwargs, ) + + + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.matshow(pair_values) + + pair_values[~pair_mask] = 20 + + import hdbscan + fig, ax = plt.subplots() + clusterer = hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, allow_single_cluster=True) + clusterer.fit(pair_values) + print(clusterer.labels_) + clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True) + #~ fig, ax = plt.subplots() + #~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + #~ edge_alpha=0.6, + #~ node_size=80, + #~ edge_linewidth=2) + + graph = clusterer.single_linkage_tree_.to_networkx() + + import scipy.cluster + fig, ax = plt.subplots() + scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) + + import networkx as nx + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() - # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") - merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + plt.show() + + + + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) @@ -187,7 +223,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -196,7 +232,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -348,6 +384,7 @@ def merge( criteria="diptest", threshold_diptest=0.5, threshold_percentile=80.0, + threshold_overlap=0.4, num_shift=2, ): if num_shift > 0: @@ -449,6 +486,23 @@ def merge( l1 = np.percentile(feat1, 100.0 - threshold_percentile) is_merge = l0 >= l1 merge_value = l0 - l1 + elif criteria == "distrib_overlap": + lim0 = min(np.min(feat0), np.min(feat1)) + lim1 = max(np.max(feat0), np.max(feat1)) + bin_size = (lim1 - lim0) / 200. 
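# The lines being added at this point implement the "distrib_overlap" criterion:
# both 1-D feature projections are histogrammed on a shared support, converted to
# probability mass per bin, and the shared mass sum(minimum(pdf0, pdf1)) is
# compared against threshold_overlap (0.4 by default in this patch).
# A self-contained sketch of the same idea, with feat0/feat1 any 1-D arrays:
#     bins = np.linspace(min(feat0.min(), feat1.min()), max(feat0.max(), feat1.max()), 201)
#     bin_size = bins[1] - bins[0]
#     pdf0, _ = np.histogram(feat0, bins=bins, density=True)
#     pdf1, _ = np.histogram(feat1, bins=bins, density=True)
#     overlap = np.sum(np.minimum(pdf0 * bin_size, pdf1 * bin_size))
#     is_merge = overlap >= threshold_overlap
# An overlap near 1 means nearly identical distributions (merge), near 0 means well separated.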
+ bins = np.arange(lim0, lim1, bin_size) + + pdf0, _ = np.histogram(feat0, bins=bins, density=True) + pdf1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 *= bin_size + pdf1 *= bin_size + overlap = np.sum(np.minimum(pdf0, pdf1)) + + is_merge = overlap >= threshold_overlap + + merge_value = 1 - overlap + + else: raise ValueError(f"bad criteria {criteria}") @@ -457,11 +511,13 @@ def merge( else: final_shift = 0 - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG and is_merge: - # if DEBUG: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -479,13 +535,16 @@ def merge( ax.legend() bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) - - count0, _ = np.histogram(feat0, bins=bins) - count1, _ = np.histogram(feat1, bins=bins) + bin_size = bins[1] - bins[0] + count0, _ = np.histogram(feat0, bins=bins, density=True) + count1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 = count0 * bin_size + pdf1 = count1 * bin_size + ax = axs[1] - ax.plot(bins[:-1], count0, color="C0") - ax.plot(bins[:-1], count1, color="C1") + ax.plot(bins[:-1], pdf0, color="C0") + ax.plot(bins[:-1], pdf1, color="C1") if criteria == "diptest": ax.set_title(f"{dipscore:.4f} {is_merge}") @@ -493,9 +552,11 @@ def merge( ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") ax.axvline(l0, color="C0") ax.axvline(l1, color="C1") + elif criteria == "distrib_overlap": + print(lim0, lim1, ) + ax.set_title(f"{overlap:.4f} {is_merge}") + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k') - - plt.show() diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 411d8c2116..d3e630a165 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -205,9 +205,11 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": - clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True) + clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True, + cluster_selection_method="leaf") clust.fit(final_features) possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) @@ -215,14 +217,15 @@ def split( mask = final_features[:, 0] > cutpoint if np.sum(mask) > min_cluster_size and np.sum(~mask): possible_labels[mask] = 1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 else: - return False, None + is_split = False else: raise ValueError(f"wrong clusterer {clusterer}") - is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -243,6 +246,8 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + axs[0].set_title(f"{clusterer} {is_split}") plt.show() From 50f6fcf5322bf10f1b8310ac228921a975b17557 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 12:16:50 +0200 Subject: [PATCH 51/74] small fix unrelated --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 364fc298c6..c921f42c6d 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -179,6 +179,8 @@ def plot(self): def _do_plot(self): import sklearn + import matplotlib.pyplot as plt + import matplotlib # compute similarity # take index of template (respect unit_ids order) From 0798169827321ca8a823780baa377ed8d5820469 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 5 Oct 2023 13:12:27 +0200 Subject: [PATCH 52/74] Update src/spikeinterface/core/waveform_extractor.py --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index eb027faf81..0fc5694207 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1507,7 +1507,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default True) + sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. From 4293b2244be7b71aa0ce68f4dabad24d23318637 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:17:03 +0000 Subject: [PATCH 53/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index c921f42c6d..468b96ff3b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -44,7 +44,7 @@ def plot(self): def _do_plot(self): from matplotlib import pyplot as plt - + fig = self.figure for ax in fig.axes: From 3371915310a4bda8cbd9ecd8a5e2d2f3e0ee55b1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 15:36:46 +0200 Subject: [PATCH 54/74] Keep sparse=False in postprocessing tests --- .../postprocessing/tests/common_extension_tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 8f864e9b84..50e2ecdb57 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -57,6 +57,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -92,6 +93,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -112,6 +114,7 @@ def setUp(self): recording, sorting, mode="memory", + sparse=False, ms_before=3.0, ms_after=4.0, 
max_spikes_per_unit=500, From bef9c4ab9d5eeea9331bfbab5076da23ef5f61cc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:09:48 +0200 Subject: [PATCH 55/74] change split merge naming --- .../sortingcomponents/clustering/merge.py | 17 +++++++++-- .../sortingcomponents/clustering/split.py | 29 ++++++++++++++----- .../sortingcomponents/tests/test_merge.py | 13 +++++++++ .../sortingcomponents/tests/test_split.py | 1 + 4 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/tests/test_merge.py diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 1dd9f9fc37..5539ec1051 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -368,8 +368,19 @@ def find_pair_function_wrapper(label0, label1): return is_merge, label0, label1, shift, merge_value -class WaveformsLda: - name = "waveforms_lda" +class ProjectDistribution: + """ + This method is a refactorized mix between: + * old tridesclous code + * some ideas by Charlie Windolf in spikespvae + + The idea is : + * project the waveform (or features) samples on a 1d axis (using LDA for instance). + * check that it is the same or not distribution (diptest, distrib_overlap, ...) + + + """ + name = "project_distribution" @staticmethod def merge( @@ -564,6 +575,6 @@ def merge( find_pair_method_list = [ - WaveformsLda, + ProjectDistribution, ] find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index d3e630a165..dc649cec97 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -13,6 +13,9 @@ from .isocut5 import isocut5 +# important all DEBUG and matplotlib are left in the code intentionally + + def split_clusters( peak_labels, recording, @@ -25,7 +28,7 @@ def split_clusters( **job_kwargs, ): """ - Run recusrsively or not in a multi process pool a local split method. + Run recusrsively (or not) in a multi process pool a local split method. Parameters ---------- @@ -151,11 +154,20 @@ def split_function_wrapper(peak_indices): return is_split, local_labels, peak_indices -class HdbscanOnLocalPca: - # @charlie : this is the equivalent of "herding_split()" in DART - # but simplified, flexible and renamed - name = "hdbscan_on_local_pca" +class LocalFeatureClustering: + """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf + + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. 
+ * run a local feature reduction (pca or svd) + * try a new split (hdscan or isocut5) + """ + + name = "local_feature_clustering" @staticmethod def split( @@ -170,6 +182,8 @@ def split( min_cluster_size=25, min_samples=25, n_pca_features=2, + minimum_common_channels=2, + ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -183,8 +197,7 @@ def split( target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) # TODO fix this a better way, this when cluster have too few overlapping channels - minimum_channels = 2 - if target_channels.size < minimum_channels: + if target_channels.size < minimum_common_channels: return False, None aligned_wfs, dont_have_channels = aggregate_sparse_features( @@ -260,6 +273,6 @@ def split( split_methods_list = [ - HdbscanOnLocalPca, + LocalFeatureClustering, ] split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py new file mode 100644 index 0000000000..b7a669a263 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -0,0 +1,13 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + +def test_merge(): + pass + + +if __name__ == "__main__": + test_merge() diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py index ed5e756469..ca5e5b57e7 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_split.py +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -3,6 +3,7 @@ from spikeinterface.sortingcomponents.clustering.split import split_clusters +# no proper test at the moment this is used in tridesclous2 def test_split(): pass From df35c6a2ba3458597e2ec3c47673cbee9e4b7182 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:22:42 +0200 Subject: [PATCH 56/74] small fixes in tests --- src/spikeinterface/core/job_tools.py | 3 +-- src/spikeinterface/core/tests/test_globals.py | 6 +++--- src/spikeinterface/core/tests/test_waveform_extractor.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a13e1dd527..e42f7bb8b4 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -434,8 +434,7 @@ def function_wrapper(args): return _func(segment_index, start_frame, end_frame, _worker_ctx) -# Here some utils - +# Here some utils copy/paste from DART (Charlie Windolf) class MockFuture: """A non-concurrent class for mocking the concurrent.futures API.""" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 8216a4aae6..2c0792c152 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -37,16 +37,16 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) 
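# Context for the asserts around this point: the global job kwargs gained two new
# fields, mp_context and max_threads_per_process, so partially specified kwargs are
# completed with them. A minimal usage sketch (the values are illustrative only):
#     from spikeinterface.core import set_global_job_kwargs, get_global_job_kwargs
#     set_global_job_kwargs(n_jobs=4, chunk_duration="1s")
#     print(get_global_job_kwargs())
#     # expected to also carry progress_bar, mp_context and max_threads_per_process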
set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..de6c3d752a 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -346,6 +346,8 @@ def test_recordingless(): # delete original recording and rely on rec_attributes if platform.system() != "Windows": + # this avoid reference on the folder + del we, recording shutil.rmtree(cache_folder / "recording1") we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) assert not we_loaded.has_recording() @@ -554,6 +556,6 @@ def test_non_json_object(): # test_WaveformExtractor() # test_extract_waveforms() # test_portability() - # test_recordingless() + test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() From 94bfb70f528603ecf22d7b499228146792fb33b9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:23:56 +0200 Subject: [PATCH 57/74] in1d to isin --- src/spikeinterface/sortingcomponents/clustering/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9a537ab8a8..8e25c9cb7f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -83,7 +83,7 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, for chan in np.unique(local_peaks["channel_index"]): sparse_chans = np.flatnonzero(sparse_mask[chan, :]) peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) - if np.all(np.in1d(target_channels, sparse_chans)): + if np.all(np.isin(target_channels, sparse_chans)): # peaks feature channel have all target_channels source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] From 48da4ea5f429eac411a331a39d9b468428b70897 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:45:52 +0200 Subject: [PATCH 58/74] wip --- src/spikeinterface/sortingcomponents/clustering/tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 8e25c9cb7f..c334daebe3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -94,6 +94,7 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, return aligned_features, dont_have_channels + def compute_template_from_sparse( peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None ): From de2d642d5f833b5c3f68df5150311b1ed5eddca8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:41:53 +0000 Subject: [PATCH 59/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/job_tools.py | 1 + src/spikeinterface/core/tests/test_globals.py | 8 +- .../core/tests/test_waveform_extractor.py | 2 +- .../sortingcomponents/clustering/merge.py | 77 ++++++++++--------- .../sortingcomponents/clustering/split.py | 18 ++--- .../sortingcomponents/clustering/tools.py | 1 - .../sortingcomponents/tests/test_merge.py | 1 + .../sortingcomponents/tests/test_split.py | 1 + 8 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index e42f7bb8b4..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -436,6 +436,7 @@ def function_wrapper(args): # Here some utils copy/paste from DART (Charlie Windolf) + class MockFuture: """A non-concurrent class for mocking the concurrent.futures API.""" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 2c0792c152..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -39,14 +39,18 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index de6c3d752a..b56180a9e9 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -558,4 +558,4 @@ def test_non_json_object(): # test_portability() test_recordingless() # test_compute_sparsity() - # test_non_json_object() + # test_non_json_object() diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 5539ec1051..d892d0723a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -80,45 +80,46 @@ def merge_clusters( method_kwargs=method_kwargs, **job_kwargs, ) - - + DEBUG = False if DEBUG: import matplotlib.pyplot as plt + fig, ax = plt.subplots() ax.matshow(pair_values) - - pair_values[~pair_mask] = 20 - + + pair_values[~pair_mask] = 20 + import hdbscan + fig, ax = plt.subplots() - clusterer = 
hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, allow_single_cluster=True) + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) clusterer.fit(pair_values) print(clusterer.labels_) - clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True) - #~ fig, ax = plt.subplots() - #~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', - #~ edge_alpha=0.6, - #~ node_size=80, - #~ edge_linewidth=2) - + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + graph = clusterer.single_linkage_tree_.to_networkx() import scipy.cluster + fig, ax = plt.subplots() scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) - + import networkx as nx + fig = plt.figure() nx.draw_networkx(graph) plt.show() plt.show() - - - + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") - # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) @@ -223,7 +224,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -232,7 +233,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -377,9 +378,10 @@ class ProjectDistribution: The idea is : * project the waveform (or features) samples on a 1d axis (using LDA for instance). * check that it is the same or not distribution (diptest, distrib_overlap, ...) - + """ + name = "project_distribution" @staticmethod @@ -412,13 +414,12 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) - if inds0.size <40 or inds1.size <40: + if inds0.size < 40 or inds1.size < 40: is_merge = False merge_value = 0 final_shift = 0 return is_merge, label0, label1, final_shift, merge_value - target_chans = np.intersect1d(target_chans0, target_chans1) inds = np.concatenate([inds0, inds1]) @@ -500,20 +501,19 @@ def merge( elif criteria == "distrib_overlap": lim0 = min(np.min(feat0), np.min(feat1)) lim1 = max(np.max(feat0), np.max(feat1)) - bin_size = (lim1 - lim0) / 200. 
+ bin_size = (lim1 - lim0) / 200.0 bins = np.arange(lim0, lim1, bin_size) - + pdf0, _ = np.histogram(feat0, bins=bins, density=True) pdf1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 *= bin_size - pdf1 *= bin_size + pdf1 *= bin_size overlap = np.sum(np.minimum(pdf0, pdf1)) - + is_merge = overlap >= threshold_overlap - + merge_value = 1 - overlap - - + else: raise ValueError(f"bad criteria {criteria}") @@ -522,13 +522,13 @@ def merge( else: final_shift = 0 - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG and is_merge: - # if DEBUG and not is_merge: - # if DEBUG and (overlap > 0.05 and overlap <0.25): - # if label0 == 49 and label1== 65: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -551,7 +551,6 @@ def merge( count1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 = count0 * bin_size pdf1 = count1 * bin_size - ax = axs[1] ax.plot(bins[:-1], pdf0, color="C0") @@ -564,13 +563,15 @@ def merge( ax.axvline(l0, color="C0") ax.axvline(l1, color="C1") elif criteria == "distrib_overlap": - print(lim0, lim1, ) + print( + lim0, + lim1, + ) ax.set_title(f"{overlap:.4f} {is_merge}") - ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k') + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") plt.show() - return is_merge, label0, label1, final_shift, merge_value diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index dc649cec97..9836e9110f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -154,13 +154,12 @@ def split_function_wrapper(peak_indices): return is_split, local_labels, peak_indices - class LocalFeatureClustering: """ This method is a refactorized mix between: * old tridesclous code * "herding_split()" in DART/spikepsvae by Charlie Windolf - + The idea simple : * agregate features (svd or even waveforms) with sparse channel. 
* run a local feature reduction (pca or svd) @@ -183,7 +182,6 @@ def split( min_samples=25, n_pca_features=2, minimum_common_channels=2, - ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -218,8 +216,12 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": - clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True, - cluster_selection_method="leaf") + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) clust.fit(final_features) possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 @@ -236,9 +238,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - - - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -259,7 +259,7 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) - + axs[0].set_title(f"{clusterer} {is_split}") plt.show() diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index c334daebe3..8e25c9cb7f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -94,7 +94,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, return aligned_features, dont_have_channels - def compute_template_from_sparse( peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None ): diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py index b7a669a263..6b3ea2a901 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_merge.py +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_merge(): pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py index ca5e5b57e7..5953f74e24 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_split.py +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_split(): pass From f2fe6bbcedc5a1cca38918444afe52e3ae1bec19 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:42:38 -0400 Subject: [PATCH 60/74] assert typo fixes round 1 --- src/spikeinterface/core/base.py | 6 +-- src/spikeinterface/core/baserecording.py | 6 +-- src/spikeinterface/core/basesorting.py | 2 +- .../core/binaryrecordingextractor.py | 2 +- .../core/channelsaggregationrecording.py | 4 +- src/spikeinterface/core/channelslice.py | 4 +- .../core/frameslicerecording.py | 2 +- src/spikeinterface/core/frameslicesorting.py | 8 ++-- src/spikeinterface/core/generate.py | 4 +- src/spikeinterface/core/template_tools.py | 41 ++++++++++--------- .../core/unitsaggregationsorting.py | 2 +- 11 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8b4f094c20..ba18cf09b6 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -45,7 +45,7 @@ def __init__(self, main_ids: Sequence) -> None: self._kwargs = {} # 'main_ids' will either be channel_ids or units_ids - # 
They is used for properties + # They are used for properties self._main_ids = np.array(main_ids) # dict at object level @@ -984,7 +984,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: class_name = None if "kwargs" not in dic: - raise Exception(f"This dict cannot be load into extractor {dic}") + raise Exception(f"This dict cannot be loaded into extractor {dic}") # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() @@ -1005,7 +1005,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic["version"]): warnings.warn( - f"Versions are not the same. This might lead compatibility errors. " + f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 08f187895b..d3572ef66b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,7 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") @@ -416,8 +416,8 @@ def set_times(self, times, segment_index=None, with_warning=True): if with_warning: warn( "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated to across preprocessing" - "Use use this carefully!" + "times are not always propagated across preprocessing" + "Use this carefully!" ) def sample_index_to_time(self, sample_ind, segment_index=None): diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e6d08d38f7..2a06a699cb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True): if check_spike_frames: if has_exceeding_spikes(recording, self): warnings.warn( - "Some spikes are exceeding the recording's duration! " + "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " "Might be necessary for further postprocessing." 
) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 72a95637f6..b45290caa5 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -91,7 +91,7 @@ def __init__( file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index d36e168f8d..8714580821 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments): times_kargs0 = parent_segment0.get_times_kwargs() if times_kargs0["time_vector"] is None: for ps in parent_segments: - assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set" + assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set" else: for ps in parent_segments: assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], ( - "All segment should have the same " "t_start" + "All segments should have the same " "t_start" ) BaseRecordingSegment.__init__(self, **times_kargs0) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index ebd1b7db03..3a21e356a6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) ), "ChannelSliceRecording: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceRecording : channel_ids not unique" + ), "ChannelSliceRecording : channel_ids are not unique" sampling_frequency = parent_recording.get_sampling_frequency() @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): ), "ChannelSliceSnippets: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceSnippets : channel_ids not unique" + ), "ChannelSliceSnippets : channel_ids are not unique" sampling_frequency = parent_snippets.get_sampling_frequency() diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 968f27c6ad..b8574c506f 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording): def __init__(self, parent_recording, start_frame=None, end_frame=None): channel_ids = parent_recording.get_channel_ids() - assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment" + assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" parent_size = parent_recording.get_num_samples(0) if start_frame is None: diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 5da5350f06..ed1391b0e2 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ 
b/src/spikeinterface/core/frameslicesorting.py @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting): def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True): unit_ids = parent_sorting.get_unit_ids() - assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment" + assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment" if start_frame is None: start_frame = 0 @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = parent_n_samples assert ( end_frame <= parent_n_samples - ), "`end_frame` should be smaller than the sortings total number of samples." + ), "`end_frame` should be smaller than the sortings' total number of samples." assert ( start_frame <= parent_n_samples - ), "`start_frame` should be smaller than the sortings total number of samples." + ), "`start_frame` should be smaller than the sortings' total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = max_spike_time + 1 assert start_frame < end_frame, ( - "`start_frame` should be greater than `end_frame`. " + "`start_frame` should be less than `end_frame`. " "This may be due to start_frame >= max_spike_time, if the end frame " "was not specified explicitly." ) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 06a5ec96ec..0c67404069 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1101,11 +1101,11 @@ def __init__( # handle also upsampling and jitter upsample_factor = templates.shape[3] elif templates.ndim == 5: - # handle also dirft + # handle also drift raise NotImplementedError("Drift will be implented soon...") # upsample_factor = templates.shape[3] else: - raise ValueError("templates have wring dim should 3 or 4") + raise ValueError("templates have wrong dim should 3 or 4") if upsample_factor is not None: assert upsample_vector is not None diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 95278b76da..552642751c 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import warnings @@ -5,7 +6,7 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"): +def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"): """ Get amplitude per channel for each unit. 
@@ -13,9 +14,9 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index @@ -24,8 +25,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st peak_values: dict Dictionary with unit ids as keys and template amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore @@ -57,7 +58,7 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st def get_template_extremum_channel( - waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id" + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id" ): """ Compute the channel with the extremum peak for each unit. @@ -66,12 +67,12 @@ def get_template_extremum_channel( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index - outputs: str + outputs: "id" | "index", default: "id" * 'id': channel id * 'index': channel index @@ -159,7 +160,7 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"): +def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -169,8 +170,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels Returns ------- @@ -203,7 +204,7 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"): +def get_template_extremum_amplitude(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"): """ Computes amplitudes on the best channel. 
@@ -211,9 +212,9 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed 'extremum': max or min 'at_index': take value at spike index @@ -223,8 +224,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", amplitudes: dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 32158f00df..4e98864ba9 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): try: property_dict[prop_name] = np.concatenate((property_dict[prop_name], values)) except Exception as e: - print(f"Skipping property '{prop_name}' for shape inconsistency") + print(f"Skipping property '{prop_name}' due to shape inconsistency") del property_dict[prop_name] break for prop_name, prop_values in property_dict.items(): From 2417b9af67a652f38e32cf24f749f9c7706554e9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 5 Oct 2023 12:11:01 -0400 Subject: [PATCH 61/74] add asserts msgs and fix typos --- src/spikeinterface/preprocessing/clip.py | 2 +- src/spikeinterface/preprocessing/common_reference.py | 2 +- .../preprocessing/detect_bad_channels.py | 4 ++-- src/spikeinterface/preprocessing/filter.py | 6 +++--- src/spikeinterface/preprocessing/filter_opencl.py | 12 ++++++------ .../preprocessing/highpass_spatial_filter.py | 2 +- src/spikeinterface/preprocessing/normalize_scale.py | 4 ++-- src/spikeinterface/preprocessing/phase_shift.py | 2 +- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..cc18d51d2e 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -97,7 +97,7 @@ def __init__( chunk_size=500, seed=0, ): - assert direction in ("upper", "lower", "both") + assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'" if fill_value is None or quantile_threshold is not None: random_data = get_random_data_chunks( diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d2ac227217..6d6ce256de 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -83,7 +83,7 @@ def __init__( ref_channel_ids = np.asarray(ref_channel_ids) assert np.all( [ch in recording.get_channel_ids() for ch in ref_channel_ids] - ), "Some wrong 'ref_channel_ids'!" + ), "Some 'ref_channel_ids' are wrong!" 
elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index cc4e8601e2..e6e2836a35 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -211,9 +211,9 @@ def detect_bad_channels( if bad_channel_ids.size > recording.get_num_channels() / 3: warnings.warn( - "Over 1/3 of channels are detected as bad. In the precense of a high" + "Over 1/3 of channels are detected as bad. In the presence of a high" "number of dead / noisy channels, bad channel detection may fail " - "(erroneously label good channels as dead)." + "(good channels may be erroneously labeled as dead)." ) elif method == "neighborhood_r2": diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 51c1fb4ad6..b31088edf7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -71,10 +71,10 @@ def __init__( ): import scipy.signal - assert filter_mode in ("sos", "ba") + assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" fs = recording.get_sampling_frequency() if coeff is None: - assert btype in ("bandpass", "highpass") + assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" # coefficient # self.coeff is 'sos' or 'ab' style filter_coeff = scipy.signal.iirfilter( @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): if dtype.kind == "u": raise TypeError( "The notch filter only supports signed types. Use the 'dtype' argument" - "to specify a signed type (e.g. 'int16', 'float32'" + "to specify a signed type (e.g. 
'int16', 'float32')" ) BasePreprocessor.__init__(self, recording, dtype=dtype) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 790279d647..d3a08297c6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -50,9 +50,9 @@ def __init__( margin_ms=5.0, ): assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)" - - assert btype in ("bandpass", "lowpass", "highpass", "bandstop") - assert filter_mode in ("sos",) + btype_modes = ("bandpass", "lowpass", "highpass", "bandstop") + assert btype in btype_modes, f"'btype' must be in {btype_modes}" + assert filter_mode in ("sos",), "'filter_mode' must be 'sos'" # coefficient sf = recording.get_sampling_frequency() @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin): self.margin = margin def get_traces(self, start_frame, end_frame, channel_indices): - assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" - assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" + assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" + assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" chunk_size = end_frame - start_frame if chunk_size != self.executor.chunk_size: @@ -157,7 +157,7 @@ def process(self, traces): if traces.shape[0] != self.full_size: if self.full_size is not None: - print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!") + print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!") self.create_buffers_and_compile() event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index aa98410568..4df4a409bc 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces * self.taper[np.newaxis, :] # apply actual HP filter - import scipy + import scipy.signal traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7d43982853..bd53866b6a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,7 +68,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("pool_channel", "by_channel") + assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'" random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -260,7 +260,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("median+mad", "mean+std") + assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'" # fix dtype dtype_ = fix_dtype(recording, dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 9c8b2589a0..237f32eca4 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,7 @@ def __init__(self, recording, margin_ms=40.0, 
inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "sample " + assert len(inter_sample_shift) == recording.get_num_channels(), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) From 9db087de50bd4b132b5e42c743dcf17fa8a9106b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:27:04 +0000 Subject: [PATCH 62/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/baserecording.py | 3 ++- src/spikeinterface/core/template_tools.py | 13 ++++++++++--- src/spikeinterface/preprocessing/phase_shift.py | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index d3572ef66b..2977211c25 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,8 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording does not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 552642751c..b6022e27c0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -6,7 +6,9 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"): +def get_template_amplitudes( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" +): """ Get amplitude per channel for each unit. @@ -58,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both def get_template_extremum_channel( - waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id" + waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + outputs: "id" | "index" = "id", ): """ Compute the channel with the extremum peak for each unit. @@ -204,7 +209,9 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"): +def get_template_extremum_amplitude( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" +): """ Computes amplitudes on the best channel. 
diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 237f32eca4..bdba55038d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,9 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "the 'inter_sample_shift' must be same size at the num_channels " + assert ( + len(inter_sample_shift) == recording.get_num_channels() + ), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) From 57078791382deed5fe73c4799bd352e6c3e0ee80 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 18:39:27 +0200 Subject: [PATCH 63/74] Fix ipywidgets with explicit dense/sparse waveforms --- .../widgets/tests/test_widgets.py | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..da16136fa9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -49,28 +49,28 @@ def setUpClass(cls): cls.num_units = len(cls.sorting.get_unit_ids()) if (cache_folder / "mearec_test").is_dir(): - cls.we = load_waveforms(cache_folder / "mearec_test") + cls.we_dense = load_waveforms(cache_folder / "mearec_test") else: - cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test") + cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test", sparse=False) sw.set_default_plotter_backend("matplotlib") metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we) - _ = compute_unit_locations(cls.we) - _ = compute_spike_locations(cls.we) - _ = compute_quality_metrics(cls.we, metric_names=metric_names) - _ = compute_template_metrics(cls.we) - _ = compute_correlograms(cls.we) - _ = compute_template_similarity(cls.we) + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) - cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) + cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) if (cache_folder / "mearec_test_sparse").is_dir(): cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") else: - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets", "ephyviewer"] @@ -124,17 +124,17 @@ def test_plot_unit_waveforms(self): possible_backends = 
list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -148,10 +148,10 @@ def test_plot_unit_templates(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_templates( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -171,7 +171,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -180,7 +180,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, + self.we_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -234,11 +234,11 @@ def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we.unit_ids[:4] - sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.we_dense.unit_ids[:4] + sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_amplitudes( self.we_sparse, @@ -252,9 +252,9 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we.unit_ids[:4] + unit_ids = self.we_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] @@ -264,7 +264,7 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - 
sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +273,7 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spike_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -282,28 +282,28 @@ def test_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): @@ -311,17 +311,17 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.we_dense, 
backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_agreement_matrix(self): @@ -355,23 +355,23 @@ def test_plot_rasters(self): mytest = TestWidgets() mytest.setUpClass() - # mytest.test_plot_unit_waveforms_density_map() - # mytest.test_plot_unit_summary() - # mytest.test_plot_all_amplitudes_distributions() - # mytest.test_plot_traces() - # mytest.test_plot_unit_waveforms() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_depths() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_summary() - # mytest.test_unit_locations() - # mytest.test_quality_metrics() - # mytest.test_template_metrics() - # mytest.test_amplitudes() - # mytest.test_plot_agreement_matrix() - # mytest.test_plot_confusion_matrix() - # mytest.test_plot_probe_map() + mytest.test_plot_unit_waveforms_density_map() + mytest.test_plot_unit_summary() + mytest.test_plot_all_amplitudes_distributions() + mytest.test_plot_traces() + mytest.test_plot_unit_waveforms() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_depths() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_summary() + mytest.test_unit_locations() + mytest.test_quality_metrics() + mytest.test_template_metrics() + mytest.test_amplitudes() + mytest.test_plot_agreement_matrix() + mytest.test_plot_confusion_matrix() + mytest.test_plot_probe_map() mytest.test_plot_rasters() # plt.ion() From 3ac58086dd8d46e02d433ee840378617d5d42e9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 06:31:41 +0000 Subject: [PATCH 64/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/tests/test_widgets.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index da16136fa9..ca53d85648 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -238,7 +238,11 @@ def test_amplitudes(self): unit_ids = self.we_dense.unit_ids[:4] sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, + unit_ids=unit_ids, + plot_histograms=True, + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_amplitudes( self.we_sparse, @@ -264,7 +268,9 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +279,9 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we_dense, with_channel_ids=True, backend=backend, 
**self.backend_kwargs[backend]) + sw.plot_spike_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) From 3448e1ec4b19d5f5091ba6a2792362cf35a9f941 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 08:57:56 +0200 Subject: [PATCH 65/74] Fix plot_traces with ipywidgets when channel_ids is not None --- src/spikeinterface/widgets/traces.py | 10 ++++++---- src/spikeinterface/widgets/utils_ipywidgets.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..2783b6a369 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -138,9 +138,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +150,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +160,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -336,6 +337,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = data_plot["channel_ids"] left_sidebar = W.VBox( children=[ diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..5bbe31302c 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -259,6 +258,19 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + 
channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): From e51bb75f226c7c2be97c4a6ceeae460a7c610efe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:25:35 +0200 Subject: [PATCH 66/74] Fix order_channel_by_depth in ipywidgets Fix order_channel_by_depth when channel_ids is given. --- src/spikeinterface/widgets/traces.py | 58 +++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 2783b6a369..802f90c62a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,6 +88,26 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") + if "location" in rec0.get_property_keys(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None + + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) + + if channel_ids is None: + channel_ids = rec0.channel_ids + + layer_keys = list(recordings.keys()) if segment_index is None: @@ -95,19 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - if channel_ids is None: - channel_ids = rec0.channel_ids - - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None - - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +131,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -202,7 +209,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -337,7 +343,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) - self.channel_selector.value = data_plot["channel_ids"] + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -400,17 +406,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + 
# order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -525,7 +531,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -552,11 +558,11 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] + # if order is not None: + # traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] + # if order is not None: + # channel_ids = np.array(channel_ids)[order] return times, list_traces, frame_range, channel_ids From bc3234cc4ce7d35cd62e0c29e33e38002f43ecd0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:52:20 +0200 Subject: [PATCH 67/74] More fix in widgets due to sparse=True by default --- .../tests/test_widgets_legacy.py | 6 +- .../widgets/tests/test_widgets.py | 57 +++++++++---------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..8814e0131a 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -32,10 +32,10 @@ def setUp(self): self.num_units = len(self._sorting.get_unit_ids()) #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_old_api").is_dir(): + self._we = load_waveforms(cache_folder / "mearec_test_old_api") else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test") + self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index ca53d85648..5f1a936a6e 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -48,22 +48,21 @@ def setUpClass(cls): cls.sorting = se.MEArecSortingExtractor(local_path) cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "mearec_test").is_dir(): - cls.we_dense = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_dense").is_dir(): + cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we_dense = 
extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test", sparse=False) + cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False) + metric_names = ["snr", "isi_violation", "num_spikes"] + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) sw.set_default_plotter_backend("matplotlib") - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we_dense) - _ = compute_unit_locations(cls.we_dense) - _ = compute_spike_locations(cls.we_dense) - _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) - _ = compute_template_metrics(cls.we_dense) - _ = compute_correlograms(cls.we_dense) - _ = compute_template_similarity(cls.we_dense) - # make sparse waveforms cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) @@ -363,24 +362,24 @@ def test_plot_rasters(self): mytest = TestWidgets() mytest.setUpClass() - mytest.test_plot_unit_waveforms_density_map() - mytest.test_plot_unit_summary() - mytest.test_plot_all_amplitudes_distributions() - mytest.test_plot_traces() - mytest.test_plot_unit_waveforms() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_depths() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_summary() - mytest.test_unit_locations() - mytest.test_quality_metrics() - mytest.test_template_metrics() - mytest.test_amplitudes() + # mytest.test_plot_unit_waveforms_density_map() + # mytest.test_plot_unit_summary() + # mytest.test_plot_all_amplitudes_distributions() + # mytest.test_plot_traces() + # mytest.test_plot_unit_waveforms() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_depths() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_summary() + # mytest.test_unit_locations() + # mytest.test_quality_metrics() + # mytest.test_template_metrics() + # mytest.test_amplitudes() mytest.test_plot_agreement_matrix() - mytest.test_plot_confusion_matrix() - mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_confusion_matrix() + # mytest.test_plot_probe_map() + # mytest.test_plot_rasters() # plt.ion() plt.show() From 7cd60ac434288e7eb9d43684e0b575396f70daaa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 07:52:41 +0000 Subject: [PATCH 68/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/tests/test_widgets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 5f1a936a6e..1a2fdf38d9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -51,7 +51,9 @@ def setUpClass(cls): if (cache_folder / "mearec_test_dense").is_dir(): cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test_dense", 
sparse=False) + cls.we_dense = extract_waveforms( + cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False + ) metric_names = ["snr", "isi_violation", "num_spikes"] _ = compute_spike_amplitudes(cls.we_dense) _ = compute_unit_locations(cls.we_dense) @@ -366,7 +368,7 @@ def test_plot_rasters(self): # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() - # mytest.test_plot_unit_waveforms() + # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() From 5c5f32fb0df19cb5faf7e24c11758639c1740f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:53:33 +0200 Subject: [PATCH 69/74] yep --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 802f90c62a..d010c96a27 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,7 +88,7 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - if "location" in rec0.get_property_keys(): + if rec0.has_channel_locations(): channel_locations = rec0.get_channel_locations() else: channel_locations = None From 986d6d9f26417740dd7162e671db3082363930f6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 10:20:20 +0200 Subject: [PATCH 70/74] Fix fix with sparse waveform extractor --- src/spikeinterface/exporters/tests/test_export_to_phy.py | 6 +++--- src/spikeinterface/exporters/to_phy.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 7528f0ebf9..39bb875ea8 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -78,7 +78,7 @@ def test_export_to_phy_by_property(): recording = recording.save(folder=rec_folder) sorting = sorting.save(folder=sort_folder) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") export_to_phy( waveform_extractor, @@ -96,7 +96,7 @@ def test_export_to_phy_by_property(): # Remove one channel recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm) + waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") export_to_phy( @@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity(): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) export_to_phy( waveform_extractor, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ebc810b953..31a452f389 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -94,6 +94,7 @@ def export_to_phy( if 
waveform_extractor.is_sparse(): used_sparsity = waveform_extractor.sparsity + assert sparsity is None elif sparsity is not None: used_sparsity = sparsity else: From 63494f2a44424085d7ad22935313f9cbd2c8b88c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 09:11:43 +0000 Subject: [PATCH 71/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/traces.py | 3 +-- src/spikeinterface/widgets/utils_ipywidgets.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d010c96a27..7a4306b284 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -95,6 +95,7 @@ def __init__( if order_channel_by_depth and channel_locations is not None: from ..preprocessing import depth_order + rec0 = depth_order(rec0) recordings = {k: depth_order(rec) for k, rec in recordings.items()} @@ -107,7 +108,6 @@ def __init__( if channel_ids is None: channel_ids = rec0.channel_ids - layer_keys = list(recordings.keys()) if segment_index is None: @@ -115,7 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - fs = rec0.get_sampling_frequency() if time_range is None: time_range = (0, 1.0) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 5bbe31302c..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,7 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -258,7 +258,7 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - + def value_changed(self, change=None): self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = change["new"] @@ -272,7 +272,6 @@ def value_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") - class ScaleWidget(W.VBox): value = traitlets.Float() From 5660de282ac43d96324184d47aa2d951910d6fec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Oct 2023 11:16:24 +0200 Subject: [PATCH 72/74] Simplify parsing in cellexplorer --- src/spikeinterface/extractors/cellexplorersortingextractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 0096a40a79..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,6 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].astype(int).tolist() unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in 
enumerate(unit_ids)} for unit_id in unit_ids: From c0d4c60095f9704f9b27adfb5fa0f4867adfaf10 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 11:38:15 +0200 Subject: [PATCH 73/74] oups --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d010c96a27..ce34af0bfa 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,7 +88,7 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - if rec0.has_channel_locations(): + if rec0.has_channel_location(): channel_locations = rec0.get_channel_locations() else: channel_locations = None From 2907934928719cf8d0403a2c55628645483187f7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 11:48:37 +0200 Subject: [PATCH 74/74] clean --- src/spikeinterface/widgets/traces.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 5a8212302c..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -557,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s return_scaled=return_scaled, ) - # if order is not None: - # traces = traces[:, order] list_traces.append(traces) - # if order is not None: - # channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids
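
Taken together, patches 63-69 above assume a workflow in which a dense WaveformExtractor is created explicitly with sparse=False, a sparse copy is derived from it, and plot_traces is given explicit channel_ids together with order_channel_by_depth=True. The sketch below illustrates that call pattern only; the toy data from generate_ground_truth_recording, the folder names, and the chosen radius are placeholders for illustration and are not taken from these patches.

    # Hedged usage sketch for the APIs touched above; data and folder names are placeholders.
    import spikeinterface.full as si

    recording, sorting = si.generate_ground_truth_recording(durations=[10.0], seed=0)

    # Patches 63/67: build an explicitly dense extractor, then derive a sparse copy from it
    # (same pattern as the reorganised widget test setup above).
    we_dense = si.extract_waveforms(recording, sorting, "we_dense", sparse=False)
    sparsity = si.compute_sparsity(we_dense, method="radius", radius_um=50)
    we_sparse = we_dense.save(folder="we_sparse", sparsity=sparsity)

    # Patches 65/66/69/73: explicit channel_ids combine with order_channel_by_depth,
    # which requires channel locations on the recording.
    si.plot_traces(recording, channel_ids=recording.channel_ids[:2],
                   order_channel_by_depth=True, backend="matplotlib")

Deriving the sparse extractor from the dense one via save(sparsity=...) keeps both variants available side by side, which is what the widget tests above rely on after the sparse=True default change.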