From 36474a8cea03c9905dd21931af2c9de418d37798 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 2 Sep 2024 17:36:23 +0200 Subject: [PATCH 01/61] Refactor the get_random_data_chunks with an internal function. to allow more methods --- src/spikeinterface/core/job_tools.py | 28 +-- src/spikeinterface/core/recording_tools.py | 163 ++++++++++++++---- .../core/tests/test_recording_tools.py | 16 +- 3 files changed, 154 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a5279247f5..1aa9ac9333 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -187,6 +187,21 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs @@ -234,18 +249,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2c7e75668f..764b6f0c66 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,7 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, ) @@ -509,6 +510,87 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned + + +def get_random_recording_slices(recording, + method="legacy", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None): + """ + Get random slice of a recording across segments. + + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + methid : "legacy" + The method used. + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames + + concatenated : bool, default: True + If True chunk are concatenated along time axis + seed : int, default: 0 + Random seed + margin_frames : int, default: 0 + Margin in number of frames to avoid edge effects + + Returns + ------- + chunk_list : np.array + Array of concatenate chunks per segment + + + """ + # TODO: if segment have differents length make another sampling that dependant on the length of the segment + # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY + # And randomize the number of chunk per segment weighted by segment duration + + if method == "legacy": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") + + return recording_slices + + def get_random_data_chunks( recording, return_scaled=False, @@ -545,41 +627,56 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - # TODO: if segment have differents length make another sampling that dependant on the length of the segment - # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY - # And randomize the number of chunk per segment weighted by segment duration - - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + # # check chunk size + # num_segments = recording.get_num_segments() + # for segment_index in range(num_segments): + # chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + # if chunk_size > chunk_size_limit: + # chunk_size = chunk_size_limit - 1 + # warnings.warn( + # f"chunk_size is greater than the number " + # f"of samples for segment index {segment_index}. " + # f"Using {chunk_size}." + # ) + + # rng = np.random.default_rng(seed) + # chunk_list = [] + # low = margin_frames + # size = num_chunks_per_segment + # for segment_index in range(num_segments): + # num_frames = recording.get_num_frames(segment_index) + # high = num_frames - chunk_size - margin_frames + # random_starts = rng.integers(low=low, high=high, size=size) + # segment_trace_chunk = [ + # recording.get_traces( + # start_frame=start_frame, + # end_frame=(start_frame + chunk_size), + # segment_index=segment_index, + # return_scaled=return_scaled, + # ) + # for start_frame in random_starts + # ] + + # chunk_list.extend(segment_trace_chunk) + + recording_slices = get_random_recording_slices(recording, + method="legacy", + num_chunks_per_segment=num_chunks_per_segment, + chunk_size=chunk_size, + # chunk_duration=chunk_duration, + margin_frames=margin_frames, + seed=seed) + print(recording_slices) - rng = np.random.default_rng(seed) chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] - - chunk_list.extend(segment_trace_chunk) + for segment_index, start_frame, stop_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=(start_frame + chunk_size), + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..e54981744d 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -333,14 +333,14 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile + # import tempfile - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # test_get_closest_channels() + # test_get_noise_levels() + # test_order_channels_by_depth() From 63574ef6a45b948f228bae94dbf090a6486279e7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 3 Sep 2024 17:08:10 +0200 Subject: [PATCH 02/61] Noise level in parallel --- src/spikeinterface/core/job_tools.py | 6 +- src/spikeinterface/core/recording_tools.py | 71 ++++++++++++++++--- .../core/tests/test_recording_tools.py | 21 +++--- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 1aa9ac9333..45d04e83df 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -389,11 +389,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, all_chunks=None): """ Runs the defined jobs. """ - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if all_chunks is None: + all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 764b6f0c66..37fcd9714a 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -19,6 +19,7 @@ ChunkRecordingExecutor, _shared_job_kwargs_doc, chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -666,7 +667,6 @@ def get_random_data_chunks( # chunk_duration=chunk_duration, margin_frames=margin_frames, seed=seed) - print(recording_slices) chunk_list = [] for segment_index, start_frame, stop_frame in recording_slices: @@ -731,12 +731,42 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + **kwargs, + # **random_chunk_kwargs, + # **job_kwargs ): """ Estimate noise for each channel using MAD methods. @@ -773,19 +803,40 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) + + # if method == "mad": + # med = np.median(random_chunks, axis=0, keepdims=True) + # # hard-coded so that core doesn't depend on scipy + # noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 + # elif method == "std": + # noise_levels = np.std(random_chunks, axis=0) + + random_slices_kwargs, job_kwargs = split_job_kwargs(kwargs) + recording_slices = get_random_recording_slices(recording,**random_slices_kwargs) + + noise_levels_chunks = [] + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name="noise_level", verbose=False, + gather_func=append_noise_chunk, **job_kwargs + ) + executor.run(all_chunks=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index e54981744d..918e15803a 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -166,6 +166,9 @@ def test_write_memory_recording(): for shm in shms: shm.unlink() +def test_get_random_recording_slices(): + # TODO + pass def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) @@ -182,16 +185,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +205,8 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) def test_get_noise_levels_output(): @@ -340,7 +344,8 @@ def test_do_recording_attributes_match(): # test_write_binary_recording(tmp_path) # test_write_memory_recording() - test_get_random_data_chunks() + # test_get_random_recording_slices() + # test_get_random_data_chunks() # test_get_closest_channels() - # test_get_noise_levels() + test_get_noise_levels() # test_order_channels_by_depth() From ce7338e1bef86448c8c94bbde070745a83070592 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 09:28:09 -0600 Subject: [PATCH 03/61] expose reading attempts --- .../extractors/neoextractors/plexon2.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 2f360ed864..7a5f463fef 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,10 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + readding_attemps : int, default: 25 + Number of attempts to read the file before raising an error + This opening process is somewhat unreliable and might fail occasionally. Adjust this higher + if you encounter problems in opening the file. Examples -------- @@ -37,7 +41,15 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): NeoRawIOClass = "Plexon2RawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + use_names_as_ids=True, + all_annotations=False, + readding_attemps: int = 25, + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( self, @@ -45,6 +57,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids stream_name=stream_name, all_annotations=all_annotations, use_names_as_ids=use_names_as_ids, + readding_attemps=readding_attemps, **neo_kwargs, ) self._kwargs.update({"file_path": str(file_path)}) From 98cd18d7d7dca067a20eb626067fd90ae784a7b8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 10:06:22 -0600 Subject: [PATCH 04/61] take into account neo version --- .../extractors/neoextractors/plexon2.py | 17 +++++++++++++---- .../extractors/tests/test_neoextractors.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 7a5f463fef..1f0d40a253 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -48,23 +48,32 @@ def __init__( stream_name=None, use_names_as_ids=True, all_annotations=False, - readding_attemps: int = 25, + reading_attempts: int = 25, ): - neo_kwargs = self.map_to_neo_kwargs(file_path) + neo_kwargs = self.map_to_neo_kwargs(file_path, reading_attempts=reading_attempts) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, use_names_as_ids=use_names_as_ids, - readding_attemps=readding_attemps, **neo_kwargs, ) self._kwargs.update({"file_path": str(file_path)}) @classmethod - def map_to_neo_kwargs(cls, file_path): + def map_to_neo_kwargs(cls, file_path, reading_attempts: int = 25): + neo_kwargs = {"filename": str(file_path)} + + from packaging.version import Version + import neo + + neo_version = Version(neo.__version__) + + if neo_version > Version("0.13.3"): + neo_kwargs["reading_attempts"] = reading_attempts + return neo_kwargs diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index acd7ebe8ad..33d02fbde2 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -359,7 +359,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] entities = [ - ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ("plexon/4chDemoPL2.pl2", {"stream_name": "WB-Wideband"}), ] From ae5edd6e717a2e030886079c278b67fc2de6a0b4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 19 Sep 2024 11:55:01 -0600 Subject: [PATCH 05/61] Update src/spikeinterface/extractors/neoextractors/plexon2.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/extractors/neoextractors/plexon2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 1f0d40a253..e0604f7496 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,7 +28,7 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. - readding_attemps : int, default: 25 + reading_attempts : int, default: 25 Number of attempts to read the file before raising an error This opening process is somewhat unreliable and might fail occasionally. Adjust this higher if you encounter problems in opening the file. From f46d13e0810ea66193206e9c49dc7bb7cc388f7c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 24 Sep 2024 12:01:36 +0200 Subject: [PATCH 06/61] Refactoring auto_merge --- src/spikeinterface/curation/auto_merge.py | 339 ++++++++++++++++------ 1 file changed, 251 insertions(+), 88 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..00c156094d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -12,8 +12,6 @@ HAVE_NUMBA = False from ..core import SortingAnalyzer, Templates -from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting @@ -25,35 +23,43 @@ _required_extensions = { "unit_locations": ["unit_locations"], "correlogram": ["correlograms"], + "min_snr": ["noise_levels", "templates"], "template_similarity": ["template_similarity"], "knn": ["spike_locations", "spike_amplitudes"], } +_templates_needed = ["unit_locations", "min_snr", "template_similarity", "spike_locations", "spike_amplitudes"] -def get_potential_auto_merge( + +def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + num_spikes_kwargs={"min_spikes": 100}, + snr_kwargs={"min_snr": 2}, + remove_contaminated_kwargs={"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + unit_locations_kwargs={"max_distance_um": 50}, + correlogram_kwargs={ + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + template_similarity_kwargs={"template_diff_thresh": 0.25}, + presence_distance_kwargs={"presence_distance_thresh": 100}, + knn_kwargs={"k_nn": 10}, + cross_contamination_kwargs={ + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + quality_score_kwargs={"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, + force_copy: bool = True, + **job_kwargs, ) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: """ Algorithm to find and check potential merges between units. @@ -98,56 +104,21 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. resolve_graph : bool, default: False If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None Which steps to run, if no preset is used. Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" - Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + Please check steps explanations above!$ + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite Returns ------- @@ -230,12 +201,24 @@ def get_potential_auto_merge( "knn", "quality_score", ] + if force_copy and compute_needed_extensions: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() for step in steps: if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): - raise ValueError(f"{step} requires {ext} extension") + if compute_needed_extensions: + if step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + params = eval(f"{step}_kwargs") + params = params.get(ext, dict()) + sorting_analyzer.compute(ext, **params, **job_kwargs) + else: + if not sorting_analyzer.has_extension(ext): + raise ValueError(f"{step} requires {ext} extension") n = unit_ids.size pair_mask = np.triu(np.arange(n)) > 0 @@ -248,33 +231,38 @@ def get_potential_auto_merge( # STEP : remove units with too few spikes if step == "num_spikes": num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < num_spikes_kwargs["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute(["noise_levels"], **job_kwargs) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < snr_kwargs["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=remove_contaminated_kwargs["refractory_period_ms"], + censored_period_ms=remove_contaminated_kwargs["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > remove_contaminated_kwargs["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel elif step == "unit_locations" in steps: @@ -282,21 +270,23 @@ def get_potential_auto_merge( unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= unit_locations_kwargs["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = correlogram_kwargs["censor_correlograms_ms"] + sigma_smooth_ms = correlogram_kwargs["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * correlogram_kwargs["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +296,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < correlogram_kwargs["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -318,18 +308,17 @@ def get_potential_auto_merge( template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < template_similarity_kwargs["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + pair_mask = get_pairs_via_nntree(sorting_analyzer, **knn_kwargs, pair_mask=pair_mask) # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + presence_distance_kwargs = presence_distance_kwargs.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -341,11 +330,14 @@ def get_potential_auto_merge( # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + refractory = ( + cross_contamination_kwargs["censored_period_ms"], + cross_contamination_kwargs["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, cross_contamination_kwargs["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > cross_contamination_kwargs["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics @@ -354,9 +346,9 @@ def get_potential_auto_merge( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + quality_score_kwargs["firing_contamination_balance"], + quality_score_kwargs["refractory_period_ms"], + quality_score_kwargs["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score @@ -364,9 +356,6 @@ def get_potential_auto_merge( ind1, ind2 = np.nonzero(pair_mask) potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) - # some methods return identities ie (1,1) which we can cleanup first. - potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] - if resolve_graph: potential_merges = resolve_merging_graph(sorting, potential_merges) @@ -376,6 +365,180 @@ def get_potential_auto_merge( return potential_merges +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 2.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-corrleogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 2.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". + k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. + + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return auto_merges( + sorting_analyzer, + preset, + resolve_graph, + num_spikes_kwargs={"min_spikes": min_spikes}, + snr_kwargs={"min_snr": min_snr}, + remove_contaminated_kwargs={ + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + unit_locations_kwargs={"max_distance_um": max_distance_um}, + correlogram_kwargs={ + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + template_similarity_kwargs={"template_diff_thresh": template_diff_thresh}, + presence_distance_kwargs={"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + knn_kwargs={"k_nn": k_nn, **knn_kwargs}, + cross_contamination_kwargs={ + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + quality_score_kwargs={ + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + compute_needed_extensions=False, + extra_outputs=extra_outputs, + steps=steps, + ) + + def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): sorting = sorting_analyzer.sorting From 4e38686836aa0a105cb01e8c9dcd25bf7f20a662 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 7 Oct 2024 20:51:52 +0200 Subject: [PATCH 07/61] feedback and clean --- src/spikeinterface/core/recording_tools.py | 75 ++++++---------------- 1 file changed, 20 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 37fcd9714a..cde6f0ced5 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -529,18 +529,19 @@ def get_random_recording_slices(recording, ---------- recording : BaseRecording The recording to get random chunks from - methid : "legacy" - The method used. + method : "legacy" + The method used to get random slices. + * "legacy" : the one used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment chunk_duration : str | float | None, default "500ms" The duration of each chunk in 's' or 'ms' chunk_size : int | None - Size of a chunk in number of frames - + Size of a chunk in number of frames. This is ued only if chunk_duration is None. concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -596,7 +597,8 @@ def get_random_data_chunks( recording, return_scaled=False, num_chunks_per_segment=20, - chunk_size=10000, + chunk_duration="500ms", + chunk_size=None, concatenated=True, seed=0, margin_frames=0, @@ -604,8 +606,6 @@ def get_random_data_chunks( """ Extract random chunks across segments - This is used for instance in get_noise_levels() to estimate noise on traces. - Parameters ---------- recording : BaseRecording @@ -614,8 +614,10 @@ def get_random_data_chunks( If True, returned chunks are scaled to uV num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. concatenated : bool, default: True If True chunk are concatenated along time axis seed : int, default: 0 @@ -628,51 +630,19 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - # # check chunk size - # num_segments = recording.get_num_segments() - # for segment_index in range(num_segments): - # chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - # if chunk_size > chunk_size_limit: - # chunk_size = chunk_size_limit - 1 - # warnings.warn( - # f"chunk_size is greater than the number " - # f"of samples for segment index {segment_index}. " - # f"Using {chunk_size}." - # ) - - # rng = np.random.default_rng(seed) - # chunk_list = [] - # low = margin_frames - # size = num_chunks_per_segment - # for segment_index in range(num_segments): - # num_frames = recording.get_num_frames(segment_index) - # high = num_frames - chunk_size - margin_frames - # random_starts = rng.integers(low=low, high=high, size=size) - # segment_trace_chunk = [ - # recording.get_traces( - # start_frame=start_frame, - # end_frame=(start_frame + chunk_size), - # segment_index=segment_index, - # return_scaled=return_scaled, - # ) - # for start_frame in random_starts - # ] - - # chunk_list.extend(segment_trace_chunk) - recording_slices = get_random_recording_slices(recording, method="legacy", num_chunks_per_segment=num_chunks_per_segment, + chunk_duration=chunk_duration, chunk_size=chunk_size, - # chunk_duration=chunk_duration, margin_frames=margin_frames, seed=seed) chunk_list = [] - for segment_index, start_frame, stop_frame in recording_slices: + for segment_index, start_frame, end_frame in recording_slices: traces_chunk = recording.get_traces( start_frame=start_frame, - end_frame=(start_frame + chunk_size), + end_frame=end_frame, segment_index=segment_index, return_scaled=return_scaled, ) @@ -773,7 +743,11 @@ def get_noise_levels( You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it use MAD estimator (more robust than STD) ot the STD on each chunk. + Finally the average on all MAD is performed. + + The result is cached in a property of the recording. + Next call on the same recording will use the cache unless force_recompute=True. Parameters ---------- @@ -803,15 +777,6 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - # random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - # if method == "mad": - # med = np.median(random_chunks, axis=0, keepdims=True) - # # hard-coded so that core doesn't depend on scipy - # noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - # elif method == "std": - # noise_levels = np.std(random_chunks, axis=0) - random_slices_kwargs, job_kwargs = split_job_kwargs(kwargs) recording_slices = get_random_recording_slices(recording,**random_slices_kwargs) From 484a5f4626c1cd160f40d08cb8f6980ea6f6b8b3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 8 Oct 2024 10:38:45 +0200 Subject: [PATCH 08/61] WIP --- src/spikeinterface/curation/auto_merge.py | 151 +++++++++++----------- 1 file changed, 77 insertions(+), 74 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 00c156094d..7a101ad609 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -35,26 +35,31 @@ def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - num_spikes_kwargs={"min_spikes": 100}, - snr_kwargs={"min_snr": 2}, - remove_contaminated_kwargs={"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - unit_locations_kwargs={"max_distance_um": 50}, - correlogram_kwargs={ - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - template_similarity_kwargs={"template_diff_thresh": 0.25}, - presence_distance_kwargs={"presence_distance_thresh": 100}, - knn_kwargs={"k_nn": 10}, - cross_contamination_kwargs={ - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - quality_score_kwargs={"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + steps_params: dict = {"num_spikes" : {"min_spikes": 100}, + "snr" : {"min_snr": 2}, + "remove_contaminated" : {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + "unit_locations" : {"max_distance_um": 50}, + "correlogram" : { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity" : {"template_diff_thresh": 0.25}, + "presence_distance" : {"presence_distance_thresh": 100}, + "knn" : {"k_nn": 10}, + "cross_contamination" : { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score" : {"firing_contamination_balance": 2.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + }, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -115,7 +120,8 @@ def auto_merges( Which steps to run, if no preset is used. Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" - Please check steps explanations above!$ + Please check steps explanations above! + steps_params : A dictionary whose keys are the steps, and keys are steps parameters. force_copy : boolean, default: True When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting already computed extensions. False if you want to overwrite @@ -140,11 +146,6 @@ def auto_merges( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - # to get fast computation we will not analyse pairs when: - # * not enough spikes for one of theses - # * auto correlogram is contaminated - # * to far away one from each other - all_steps = [ "num_spikes", "snr", @@ -227,11 +228,13 @@ def auto_merges( for step in steps: assert step in all_steps, f"{step} is not a valid step" + params = steps_params.get(step, {}) # STEP : remove units with too few spikes if step == "num_spikes": + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < num_spikes_kwargs["min_spikes"] + to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["num_spikes"] = to_remove @@ -245,7 +248,7 @@ def auto_merges( qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < snr_kwargs["min_snr"] + to_remove = snrs < params["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["snr"] = to_remove @@ -254,12 +257,12 @@ def auto_merges( elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( sorting_analyzer, - refractory_period_ms=remove_contaminated_kwargs["refractory_period_ms"], - censored_period_ms=remove_contaminated_kwargs["censored_period_ms"], + refractory_period_ms=params["refractory_period_ms"], + censored_period_ms=params["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > remove_contaminated_kwargs["contamination_thresh"] + to_remove = contaminations > params["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["remove_contaminated"] = to_remove @@ -270,15 +273,15 @@ def auto_merges( unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= unit_locations_kwargs["max_distance_um"]) + pair_mask = pair_mask & (unit_distances <= params["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - censor_ms = correlogram_kwargs["censor_correlograms_ms"] - sigma_smooth_ms = correlogram_kwargs["sigma_smooth_ms"] + censor_ms = params["censor_correlograms_ms"] + sigma_smooth_ms = params["sigma_smooth_ms"] mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) @@ -286,7 +289,7 @@ def auto_merges( win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * correlogram_kwargs["adaptative_window_thresh"] + thresh = np.max(auto_corr) * params["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -296,7 +299,7 @@ def auto_merges( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < correlogram_kwargs["corr_diff_thresh"]) + pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -308,16 +311,16 @@ def auto_merges( template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_similarity_kwargs["template_diff_thresh"]) + pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes elif step == "knn" in steps: - pair_mask = get_pairs_via_nntree(sorting_analyzer, **knn_kwargs, pair_mask=pair_mask) + pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs.copy() + presence_distance_kwargs = params.copy() presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) @@ -331,13 +334,13 @@ def auto_merges( # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: refractory = ( - cross_contamination_kwargs["censored_period_ms"], - cross_contamination_kwargs["refractory_period_ms"], + params["censored_period_ms"], + params["refractory_period_ms"], ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cross_contamination_kwargs["cc_thresh"], refractory, contaminations + sorting_analyzer, pair_mask, params["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > cross_contamination_kwargs["p_value"]) + pair_mask = pair_mask & (p_values > params["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics @@ -346,9 +349,9 @@ def auto_merges( sorting_analyzer, pair_mask, contaminations, - quality_score_kwargs["firing_contamination_balance"], - quality_score_kwargs["refractory_period_ms"], - quality_score_kwargs["censored_period_ms"], + params["firing_contamination_balance"], + params["refractory_period_ms"], + params["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score @@ -505,34 +508,34 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - num_spikes_kwargs={"min_spikes": min_spikes}, - snr_kwargs={"min_snr": min_snr}, - remove_contaminated_kwargs={ - "contamination_thresh": contamination_thresh, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - unit_locations_kwargs={"max_distance_um": max_distance_um}, - correlogram_kwargs={ - "corr_diff_thresh": corr_diff_thresh, - "censor_correlograms_ms": censor_correlograms_ms, - "sigma_smooth_ms": sigma_smooth_ms, - "adaptative_window_thresh": adaptative_window_thresh, - }, - template_similarity_kwargs={"template_diff_thresh": template_diff_thresh}, - presence_distance_kwargs={"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, - knn_kwargs={"k_nn": k_nn, **knn_kwargs}, - cross_contamination_kwargs={ - "cc_thresh": cc_thresh, - "p_value": p_value, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - quality_score_kwargs={ - "firing_contamination_balance": firing_contamination_balance, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, + step_params={"num_spikes" : {"min_spikes": min_spikes}, + "snr_kwargs" : {"min_snr": min_snr}, + "remove_contaminated_kwargs" : { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations" : {"max_distance_um": max_distance_um}, + "correlogram" : { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance" : {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn" : {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination" : { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score" : { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }}, compute_needed_extensions=False, extra_outputs=extra_outputs, steps=steps, From 35ad317e619be60abbdd40f1da41a167171be1c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:41:42 +0000 Subject: [PATCH 09/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 107 +++++++++++----------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a101ad609..db3300f0d2 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -35,31 +35,28 @@ def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - steps_params: dict = {"num_spikes" : {"min_spikes": 100}, - "snr" : {"min_snr": 2}, - "remove_contaminated" : {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, - "unit_locations" : {"max_distance_um": 50}, - "correlogram" : { - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - "template_similarity" : {"template_diff_thresh": 0.25}, - "presence_distance" : {"presence_distance_thresh": 100}, - "knn" : {"k_nn": 10}, - "cross_contamination" : { - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - "quality_score" : {"firing_contamination_balance": 2.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, - }, + steps_params: dict = { + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 50}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + }, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -232,7 +229,7 @@ def auto_merges( # STEP : remove units with too few spikes if step == "num_spikes": - + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False @@ -508,34 +505,36 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - step_params={"num_spikes" : {"min_spikes": min_spikes}, - "snr_kwargs" : {"min_snr": min_snr}, - "remove_contaminated_kwargs" : { - "contamination_thresh": contamination_thresh, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - "unit_locations" : {"max_distance_um": max_distance_um}, - "correlogram" : { - "corr_diff_thresh": corr_diff_thresh, - "censor_correlograms_ms": censor_correlograms_ms, - "sigma_smooth_ms": sigma_smooth_ms, - "adaptative_window_thresh": adaptative_window_thresh, - }, - "template_similarity": {"template_diff_thresh": template_diff_thresh}, - "presence_distance" : {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, - "knn" : {"k_nn": k_nn, **knn_kwargs}, - "cross_contamination" : { - "cc_thresh": cc_thresh, - "p_value": p_value, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - "quality_score" : { - "firing_contamination_balance": firing_contamination_balance, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }}, + step_params={ + "num_spikes": {"min_spikes": min_spikes}, + "snr_kwargs": {"min_snr": min_snr}, + "remove_contaminated_kwargs": { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations": {"max_distance_um": max_distance_um}, + "correlogram": { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance": {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn": {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination": { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score": { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + }, compute_needed_extensions=False, extra_outputs=extra_outputs, steps=steps, From 49c7a92a57af5a65f7367b375567afaed6abda56 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 15 Oct 2024 13:46:36 +0200 Subject: [PATCH 10/61] Use existing sparsity for unit location + add location with max channel --- .../postprocessing/localization_tools.py | 67 +++++++++++++++++-- .../tests/test_unit_locations.py | 1 + 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59ca8cf7db 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,8 +76,12 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() + + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity - sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -157,9 +161,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -650,8 +658,59 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) + +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. + + This use inetrnally get_template_extremum_channel() + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: str | int | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_location: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + mode=mode, + outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_location + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): From 9cf9377a30b1733223037c58bb05709f0e76d5c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:50:21 +0000 Subject: [PATCH 11/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/localization_tools.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 59ca8cf7db..4bf39e00e8 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,7 +76,7 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() - + if sorting_analyzer_or_templates.sparsity is None: sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) else: @@ -167,7 +167,7 @@ def compute_center_of_mass( ) else: sparsity = sorting_analyzer_or_templates.sparsity - + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -658,7 +658,6 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) - def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, @@ -691,10 +690,7 @@ def compute_location_max_channel( 2d """ extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, - peak_sign=peak_sign, - mode=mode, - outputs="index" + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" ) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: From 3bf9b4884de04c89ffd0e89c647a9c151c27ed96 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 15 Oct 2024 14:54:30 +0200 Subject: [PATCH 12/61] Fixing tests --- src/spikeinterface/curation/auto_merge.py | 46 ++++++++++++----------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index db3300f0d2..7a8404d076 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -38,7 +38,9 @@ def auto_merges( steps_params: dict = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 50}, "correlogram": { "corr_diff_thresh": 0.16, @@ -55,7 +57,9 @@ def auto_merges( "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 2.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, }, compute_needed_extensions: bool = True, extra_outputs: bool = False, @@ -203,21 +207,6 @@ def auto_merges( # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() - for step in steps: - if step in _required_extensions: - for ext in _required_extensions[step]: - if compute_needed_extensions: - if step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - params = eval(f"{step}_kwargs") - params = params.get(ext, dict()) - sorting_analyzer.compute(ext, **params, **job_kwargs) - else: - if not sorting_analyzer.has_extension(ext): - raise ValueError(f"{step} requires {ext} extension") - n = unit_ids.size pair_mask = np.triu(np.arange(n)) > 0 outs = dict() @@ -225,7 +214,20 @@ def auto_merges( for step in steps: assert step in all_steps, f"{step} is not a valid step" - params = steps_params.get(step, {}) + + if step in _required_extensions: + for ext in _required_extensions[step]: + if compute_needed_extensions and step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + print(f"Extension {ext} is computed with default params") + sorting_analyzer.compute(ext, **job_kwargs) + elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): + raise ValueError(f"{step} requires {ext} extension") + + + params = steps_params.get(step, dict()) # STEP : remove units with too few spikes if step == "num_spikes": @@ -240,7 +242,7 @@ def auto_merges( elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute(["noise_levels"], **job_kwargs) + sorting_analyzer.compute("noise_levels", **job_kwargs) sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") @@ -505,10 +507,10 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - step_params={ + steps_params={ "num_spikes": {"min_spikes": min_spikes}, - "snr_kwargs": {"min_snr": min_snr}, - "remove_contaminated_kwargs": { + "snr": {"min_snr": min_snr}, + "remove_contaminated": { "contamination_thresh": contamination_thresh, "refractory_period_ms": refractory_period_ms, "censored_period_ms": censored_period_ms, From 3df19c2e11117e9b69be4416bdb1123637ce63e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:55:03 +0000 Subject: [PATCH 13/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a8404d076..d38b717bc8 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -38,9 +38,7 @@ def auto_merges( steps_params: dict = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 50}, "correlogram": { "corr_diff_thresh": 0.16, @@ -57,9 +55,7 @@ def auto_merges( "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, }, compute_needed_extensions: bool = True, extra_outputs: bool = False, @@ -226,7 +222,6 @@ def auto_merges( elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") - params = steps_params.get(step, dict()) # STEP : remove units with too few spikes From 51edfece2f8ef041774bd2b27582021431e0f93d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 15 Oct 2024 15:00:18 +0200 Subject: [PATCH 14/61] Fixing tests --- src/spikeinterface/curation/auto_merge.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a8404d076..4966db4247 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -23,12 +23,12 @@ _required_extensions = { "unit_locations": ["unit_locations"], "correlogram": ["correlograms"], - "min_snr": ["noise_levels", "templates"], + "snr": ["noise_levels", "templates"], "template_similarity": ["template_similarity"], "knn": ["spike_locations", "spike_amplitudes"], } -_templates_needed = ["unit_locations", "min_snr", "template_similarity", "spike_locations", "spike_amplitudes"] +_templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] def auto_merges( @@ -242,7 +242,6 @@ def auto_merges( elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels", **job_kwargs) sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") @@ -537,7 +536,7 @@ def get_potential_auto_merge( "censored_period_ms": censored_period_ms, }, }, - compute_needed_extensions=False, + compute_needed_extensions=True, extra_outputs=extra_outputs, steps=steps, ) From c26b7199e086c9a3e48c99aa0495f540206e44a4 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 16 Oct 2024 14:27:14 +0200 Subject: [PATCH 15/61] Default params --- src/spikeinterface/curation/auto_merge.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 03c8c131a9..39a155ec09 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -31,11 +31,7 @@ _templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] -def auto_merges( - sorting_analyzer: SortingAnalyzer, - preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, - steps_params: dict = { +_default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, @@ -56,7 +52,14 @@ def auto_merges( "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - }, + } + + +def auto_merges( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + steps_params: dict = None, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -222,7 +225,9 @@ def auto_merges( elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") - params = steps_params.get(step, dict()) + params = _default_step_params.get(step).copy() + if step in steps_params: + params.update(steps_params[step]) # STEP : remove units with too few spikes if step == "num_spikes": From 9692fb0fbaf294c323edc4bbaeb66d3347e2145c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:31:36 +0000 Subject: [PATCH 16/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 42 +++++++++++------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 39a155ec09..e337b3d99d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -32,27 +32,27 @@ _default_step_params = { - "num_spikes": {"min_spikes": 100}, - "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "unit_locations": {"max_distance_um": 50}, - "correlogram": { - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - "template_similarity": {"template_diff_thresh": 0.25}, - "presence_distance": {"presence_distance_thresh": 100}, - "knn": {"k_nn": 10}, - "cross_contamination": { - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - } + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 50}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, +} def auto_merges( From 3c277b3445dc05760d617c249fbc58a43b2d7ace Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 16 Oct 2024 14:35:24 +0200 Subject: [PATCH 17/61] Precomputing extensions --- src/spikeinterface/curation/auto_merge.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 39a155ec09..fcf5fd8fd9 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -216,12 +216,15 @@ def auto_merges( if step in _required_extensions: for ext in _required_extensions[step]: - if compute_needed_extensions and step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - print(f"Extension {ext} is computed with default params") - sorting_analyzer.compute(ext, **job_kwargs) + if compute_needed_extensions: + if step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + res_ext = sorting_analyzer.get_extension(step) + if res_ext is None: + print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed") + sorting_analyzer.compute(ext, **job_kwargs) elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") From a3d1c2c4f49025e01bf17b13b76b931cd071e938 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:35:58 +0000 Subject: [PATCH 18/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ffc4fea78b..86f47af0eb 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -223,7 +223,9 @@ def auto_merges( sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) res_ext = sorting_analyzer.get_extension(step) if res_ext is None: - print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed") + print( + f"Extension {ext} is computed with default params. Precompute it with custom params if needed" + ) sorting_analyzer.compute(ext, **job_kwargs) elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") From d8ee9da3dbc5599dc2876b47933aab8b352fab71 Mon Sep 17 00:00:00 2001 From: rainsong <57996958+522848942@users.noreply.github.com> Date: Mon, 21 Oct 2024 00:06:19 +0800 Subject: [PATCH 19/61] Update core.rst doc error --- doc/modules/core.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 8aa1815a55..5df9a7e6b1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -385,7 +385,7 @@ and merging unit groups. sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3]) sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0]) - sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3]) + sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]]) All computed extensions will be automatically propagated or merged when curating. Please refer to the :ref:`modules/curation:Curation module` documentation for more information. From 0e5f50fdfb6b6ae35a9b06d814e811ac9bf833ee Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 13:54:42 +0200 Subject: [PATCH 20/61] merci zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/postprocessing/localization_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 4bf39e00e8..a17abea1eb 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,14 +667,14 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This use inetrnally get_template_extremum_channel() + This uses interrnally `get_template_extremum_channel()` Parameters ---------- templates_or_sorting_analyzer : SortingAnalyzer | Templates A SortingAnalyzer or Templates object - unit_ids: str | int | None + unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels From 1411c6fdf9d89c68a205c1512b1dce9ce2a29c62 Mon Sep 17 00:00:00 2001 From: OlivierPeron <79974181+OlivierPeron@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:59:20 +0200 Subject: [PATCH 21/61] Loading templates Loading templates whatever the operator --- src/spikeinterface/core/template_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 934b18ed49..769610ad2b 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,7 +31,8 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] + templates_array = ext.data.get("average") or ext.data.get("median") + assert templates_array is not None, "Average or median templates have not been computed." else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: From 72357a68a25acc58c6634300d574e056a1e46857 Mon Sep 17 00:00:00 2001 From: OlivierPeron <79974181+OlivierPeron@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:22:52 +0200 Subject: [PATCH 22/61] Update src/spikeinterface/core/template_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/template_tools.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 769610ad2b..3c8663df70 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,8 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data.get("average") or ext.data.get("median") - assert templates_array is not None, "Average or median templates have not been computed." + if "average" in ext.data: + templates_array = ext.data.get("average") + elif "median" in ext.data: + templates_array = ext.data.get("median") + else: + raise ValueError("Average or median templates have not been computed.") else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: From f4dd922447eac6d422e43093a5b93a8877df34be Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 12:44:26 -0400 Subject: [PATCH 23/61] better error message (#3479) --- src/spikeinterface/core/baserecordingsnippets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 310533c96b..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( - f"The given Probe have 'device_channel_indices' that do not match channel count \n" - f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n" + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) From 0be00cf32fad46b1a55d7018ea051014644568ab Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:14 +0200 Subject: [PATCH 24/61] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a17abea1eb..3372a34c98 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,7 +667,7 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This uses interrnally `get_template_extremum_channel()` + This uses internally `get_template_extremum_channel()` Parameters From 0002edcbe99764fcf65edae405a2be59647691f1 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:23 +0200 Subject: [PATCH 25/61] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 3372a34c98..67d469f85c 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -686,7 +686,7 @@ def compute_location_max_channel( Returns ------- - unit_location: np.ndarray + unit_locations: np.ndarray 2d """ extremum_channels_index = get_template_extremum_channel( From b4e681d8524d119a86ba21fd9a19c5e6716c1ca0 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:33 +0200 Subject: [PATCH 26/61] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 67d469f85c..a073b6c518 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -697,11 +697,11 @@ def compute_location_max_channel( unit_ids = templates_or_sorting_analyzer.unit_ids else: unit_ids = np.asarray(unit_ids) - unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") for i, unit_id in enumerate(unit_ids): - unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] - return unit_location + return unit_locations _unit_location_methods = { From 55b50abdebf5f1e0aa0c00307d706588b1d9038d Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:54 +0200 Subject: [PATCH 27/61] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a073b6c518..837b983059 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -684,7 +684,7 @@ def compute_location_max_channel( * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - Returns + Returns ------- unit_locations: np.ndarray 2d From c2f980c6bb40df45d411904dd5f58288fa7b8ad3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:28:20 -0400 Subject: [PATCH 28/61] typos --- doc/modules/curation.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d115b33e4a..d24fc810b0 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes redundant units from the sorting output. Redundant units are units that share over a certain percentage of spikes, by default 80%. -The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. +The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. .. code-block:: python @@ -102,13 +102,16 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. ) # remove redundant units from SortingAnalyzer object - clean_sorting_analyzer = remove_redundant_units( + # note this returns a cleaned sorting + clean_sorting = remove_redundant_units( sorting_analyzer, duplicate_threshold=0.9, remove_strategy="min_shift" ) + # in order to have a sorter with only the non-redundant units do: + clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) -We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps +We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps the unit (among the redundant ones), with a better template alignment. From 61ce0007476b9a1e80d69bdbbc42e70d8bcce626 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:32:37 -0400 Subject: [PATCH 29/61] better comment --- doc/modules/curation.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d24fc810b0..37de992806 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -108,7 +108,9 @@ The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. duplicate_threshold=0.9, remove_strategy="min_shift" ) - # in order to have a sorter with only the non-redundant units do: + # in order to have a SortingAnalyer with only the non-redundant units one must + # select the designed units remembering to give format and folder if one wants + # a persistent SortingAnalyzer. clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps From 1ce2f8dc411acadf7bdf6bdd537f616d448a536c Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 23 Oct 2024 14:58:41 +0200 Subject: [PATCH 30/61] merci alessio Co-authored-by: Alessio Buccino --- src/spikeinterface/core/recording_tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 01f541ddac..790511ad88 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -744,11 +744,11 @@ def get_noise_levels( You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) ot the STD on each chunk. - Finally the average on all MAD is performed. + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. - The result is cached in a property of the recording. - Next call on the same recording will use the cache unless force_recompute=True. + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- From eb6219999ce1d8d1c898575ec9a71045c6a8bcb1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 15:50:51 +0200 Subject: [PATCH 31/61] wip --- src/spikeinterface/core/recording_tools.py | 27 +++++++------------ .../preprocessing/silence_periods.py | 4 ++- .../preprocessing/tests/test_silence.py | 5 +++- .../preprocessing/tests/test_whiten.py | 7 +++-- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 790511ad88..a4feff4d14 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -601,11 +601,16 @@ def get_random_data_chunks( recording, return_scaled=False, concatenated=True, - random_slices_kwargs={}, - **kwargs, + **random_slices_kwargs ): """ - Extract random chunks across segments + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. + Parameters ---------- @@ -617,7 +622,7 @@ def get_random_data_chunks( Number of chunks per segment concatenated : bool, default: True If True chunk are concatenated along time axis - random_slices_kwargs : dict + **random_slices_kwargs : dict Options transmited to get_random_recording_slices(), please read documentation from this function for more details. @@ -626,18 +631,6 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - if len(kwargs) > 0: - # This is to keep backward compatibility - # lets keep for a while and remove this maybe in 0.103.0 - msg = ( - "get_random_data_chunks(recording, num_chunks_per_segment=20) is deprecated\n" - "Now, you need to use get_random_data_chunks(recording, random_slices_kwargs=dict(num_chunks_per_segment=20))\n" - "Please read get_random_recording_slices() documentation for more options." - ) - assert len(random_slices_kwargs) ==0, msg - warnings.warn(msg) - random_slices_kwargs = kwargs - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) chunk_list = [] @@ -797,7 +790,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + recording_slices = get_random_recording_slices(recording, random_slices_kwargs) noise_levels_chunks = [] def append_noise_chunk(res): diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..3129acd3f3 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_chunk_kwargs = random_chunk_kwargs.copy() + random_chunk_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_chunk_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..6405b6b0c4 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..3444323488 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder) From 68b4b200907be01e63149dd673b49f1f02f9b821 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 17:14:06 +0200 Subject: [PATCH 32/61] small updates on auto merge + renaming --- src/spikeinterface/curation/__init__.py | 2 +- src/spikeinterface/curation/auto_merge.py | 197 ++++++++++-------- .../curation/tests/test_auto_merge.py | 49 +++-- 3 files changed, 137 insertions(+), 111 deletions(-) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 657b936fb9..579e47a553 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge, get_potential_auto_merge # manual sorting, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 86f47af0eb..16147a6225 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Tuple import numpy as np import math @@ -17,19 +19,50 @@ from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph - -_possible_presets = ["similarity_correlograms", "x_contaminations", "temporal_splits", "feature_neighbors"] +_compute_merge_persets = { + "similarity_correlograms":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "quality_score", + ], + "temporal_splits":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "presence_distance", + "quality_score", + ], + "x_contaminations":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "quality_score", + ], + "feature_neighbors":[ + "num_spikes", + "snr", + "remove_contaminated", + "unit_locations", + "knn", + "quality_score", + ] +} _required_extensions = { - "unit_locations": ["unit_locations"], + "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["noise_levels", "templates"], - "template_similarity": ["template_similarity"], - "knn": ["spike_locations", "spike_amplitudes"], + "snr": ["templates","noise_levels", "templates"], + "template_similarity": ["templates", "template_similarity"], + "knn": ["templates", "spike_locations", "spike_amplitudes"], + "spike_amplitudes" : ["templates"], } -_templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] - _default_step_params = { "num_spikes": {"min_spikes": 100}, @@ -55,17 +88,18 @@ } -def auto_merges( + +def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, + resolve_graph: bool = True, steps_params: dict = None, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, force_copy: bool = True, **job_kwargs, -) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: +) -> list[tuple[int | str, int | str]] | Tuple[list[tuple[int | str, int | str]], dict]: """ Algorithm to find and check potential merges between units. @@ -110,7 +144,7 @@ def auto_merges( | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" If `preset` is None, you can specify the steps manually with the `steps` parameter. - resolve_graph : bool, default: False + resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. compute_needed_extensions : bool, default : True Should we force the computation of needed extensions? @@ -128,9 +162,10 @@ def auto_merges( Returns ------- - potential_merges: - A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). - List of pairs that could be merged. + merge_unit_groups: + List of groups that need to be merge. + When `resolve_graph` is true (default) a list of tuples of 2+ elements + If `resolve_graph` is false then a list of tuple of 2 elements is returned instead. outs: Returned only when extra_outputs=True A dictionary that contains data for debugging and plotting. @@ -146,62 +181,17 @@ def auto_merges( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - all_steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "correlogram", - "template_similarity", - "presence_distance", - "knn", - "cross_contamination", - "quality_score", - ] - if preset is not None and preset not in _possible_presets: - raise ValueError(f"preset must be one of {_possible_presets}") - - if steps is None: - if preset is None: - if steps is None: - raise ValueError("You need to specify a preset or steps for the auto-merge function") - elif preset == "similarity_correlograms": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "correlogram", - "quality_score", - ] - elif preset == "temporal_splits": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "presence_distance", - "quality_score", - ] - elif preset == "x_contaminations": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "cross_contamination", - "quality_score", - ] - elif preset == "feature_neighbors": - steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "knn", - "quality_score", - ] + if preset is None and steps is None: + raise ValueError("You need to specify a preset or steps for the auto-merge function") + elif steps is not None: + # steps has presendance on presets + pass + elif preset is not None: + if preset not in _compute_merge_persets: + raise ValueError(f"preset must be one of {list(_compute_merge_persets.keys())}") + steps = _compute_merge_persets[preset] + if force_copy and compute_needed_extensions: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -212,26 +202,23 @@ def auto_merges( for step in steps: - assert step in all_steps, f"{step} is not a valid step" + assert step in _default_step_params, f"{step} is not a valid step" if step in _required_extensions: for ext in _required_extensions[step]: - if compute_needed_extensions: - if step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - res_ext = sorting_analyzer.get_extension(step) - if res_ext is None: - print( - f"Extension {ext} is computed with default params. Precompute it with custom params if needed" - ) - sorting_analyzer.compute(ext, **job_kwargs) - elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") + + # special case for templates + if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + else: + sorting_analyzer.compute(ext, **job_kwargs) params = _default_step_params.get(step).copy() - if step in steps_params: + if steps_params is not None and step in steps_params: params.update(steps_params[step]) # STEP : remove units with too few spikes @@ -360,15 +347,38 @@ def auto_merges( # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) - potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) + merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) if resolve_graph: - potential_merges = resolve_merging_graph(sorting, potential_merges) + merge_unit_groups = resolve_merging_graph(sorting, merge_unit_groups) if extra_outputs: - return potential_merges, outs + return merge_unit_groups, outs else: - return potential_merges + return merge_unit_groups + +def auto_merge( + sorting_analyzer: SortingAnalyzer, + compute_merge_kwargs:dict = {}, + apply_merge_kwargs: dict = {}, + **job_kwargs + ) -> SortingAnalyzer: + """ + Compute merge unit groups and apply it on a SortingAnalyzer. + Internally uses `compute_merge_unit_groups()` + """ + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, + extra_outputs=False, + **compute_merge_kwargs, + **job_kwargs + ) + + merged_analyzer = sorting_analyzer.merge_units( + merge_unit_groups, **apply_merge_kwargs, **job_kwargs + ) + return merged_analyzer + def get_potential_auto_merge( @@ -397,6 +407,9 @@ def get_potential_auto_merge( steps: list[str] | None = None, ) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: """ + This function is deprecated. Use compute_merge_unit_groups() instead. + This will be removed in 0.103.0 + Algorithm to find and check potential merges between units. The merges are proposed based on a series of steps with different criteria: @@ -505,9 +518,15 @@ def get_potential_auto_merge( done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py """ + warnings.warn( + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) + presence_distance_kwargs = presence_distance_kwargs or dict() knn_kwargs = knn_kwargs or dict() - return auto_merges( + return compute_merge_unit_groups( sorting_analyzer, preset, resolve_graph, diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..ebd7bf1504 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,16 +3,16 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import get_potential_auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @pytest.mark.parametrize( - "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"] + "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None] ) -def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): +def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting @@ -47,32 +47,37 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): ) if preset is not None: - potential_merges, outs = get_potential_auto_merge( + # do not resolve graph for checking true pairs + merge_unit_groups, outs = compute_merge_unit_groups( sorting_analyzer, preset=preset, - min_spikes=1000, - max_distance_um=150.0, - contamination_thresh=0.2, - corr_diff_thresh=0.16, - template_diff_thresh=0.25, - censored_period_ms=0.0, - refractory_period_ms=4.0, - sigma_smooth_ms=0.6, - adaptative_window_thresh=0.5, - firing_contamination_balance=1.5, + resolve_graph=False, + # min_spikes=1000, + # max_distance_um=150.0, + # contamination_thresh=0.2, + # corr_diff_thresh=0.16, + # template_diff_thresh=0.25, + # censored_period_ms=0.0, + # refractory_period_ms=4.0, + # sigma_smooth_ms=0.6, + # adaptative_window_thresh=0.5, + # firing_contamination_balance=1.5, extra_outputs=True, + **job_kwargs ) if preset == "x_contaminations": - assert len(potential_merges) == num_unit_splited + assert len(merge_unit_groups) == num_unit_splited for true_pair in other_ids.values(): true_pair = tuple(true_pair) - assert true_pair in potential_merges + assert true_pair in merge_unit_groups else: # when preset is None you have to specify the steps with pytest.raises(ValueError): - potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) - potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, preset=preset, + steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], + **job_kwargs ) # DEBUG @@ -93,7 +98,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): # m = correlograms.shape[2] // 2 - # for unit_id1, unit_id2 in potential_merges[:5]: + # for unit_id1, unit_id2 in merge_unit_groups[:5]: # unit_ind1 = sorting_with_split.id_to_index(unit_id1) # unit_ind2 = sorting_with_split.id_to_index(unit_id2) @@ -129,4 +134,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_get_auto_merge_list(sorting_analyzer) + # preset = "x_contaminations" + preset = None + test_compute_merge_unit_groups(sorting_analyzer, preset=preset) From 4476d4ccc6bde244561936b8ed22c9b7a0032113 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:15:47 +0000 Subject: [PATCH 33/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 44 +++++++------------ .../curation/tests/test_auto_merge.py | 7 +-- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 16147a6225..ec5e8be20c 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -20,7 +20,7 @@ from .curation_tools import resolve_merging_graph _compute_merge_persets = { - "similarity_correlograms":[ + "similarity_correlograms": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -28,7 +28,7 @@ "correlogram", "quality_score", ], - "temporal_splits":[ + "temporal_splits": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -36,7 +36,7 @@ "presence_distance", "quality_score", ], - "x_contaminations":[ + "x_contaminations": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -44,23 +44,23 @@ "cross_contamination", "quality_score", ], - "feature_neighbors":[ + "feature_neighbors": [ "num_spikes", "snr", "remove_contaminated", "unit_locations", "knn", "quality_score", - ] + ], } _required_extensions = { "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["templates","noise_levels", "templates"], + "snr": ["templates", "noise_levels", "templates"], "template_similarity": ["templates", "template_similarity"], "knn": ["templates", "spike_locations", "spike_amplitudes"], - "spike_amplitudes" : ["templates"], + "spike_amplitudes": ["templates"], } @@ -88,7 +88,6 @@ } - def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", @@ -181,7 +180,6 @@ def compute_merge_unit_groups( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - if preset is None and steps is None: raise ValueError("You need to specify a preset or steps for the auto-merge function") elif steps is not None: @@ -210,7 +208,7 @@ def compute_merge_unit_groups( continue if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") - + # special case for templates if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) @@ -357,30 +355,22 @@ def compute_merge_unit_groups( else: return merge_unit_groups + def auto_merge( - sorting_analyzer: SortingAnalyzer, - compute_merge_kwargs:dict = {}, - apply_merge_kwargs: dict = {}, - **job_kwargs - ) -> SortingAnalyzer: + sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +) -> SortingAnalyzer: """ Compute merge unit groups and apply it on a SortingAnalyzer. Internally uses `compute_merge_unit_groups()` """ merge_unit_groups = compute_merge_unit_groups( - sorting_analyzer, - extra_outputs=False, - **compute_merge_kwargs, - **job_kwargs + sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs ) - merged_analyzer = sorting_analyzer.merge_units( - merge_unit_groups, **apply_merge_kwargs, **job_kwargs - ) + merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) return merged_analyzer - def get_potential_auto_merge( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", @@ -519,10 +509,10 @@ def get_potential_auto_merge( https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py """ warnings.warn( - "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", - DeprecationWarning, - stacklevel=2, - ) + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) presence_distance_kwargs = presence_distance_kwargs or dict() knn_kwargs = knn_kwargs or dict() diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index ebd7bf1504..4c05f41a4c 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -63,7 +63,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): # adaptative_window_thresh=0.5, # firing_contamination_balance=1.5, extra_outputs=True, - **job_kwargs + **job_kwargs, ) if preset == "x_contaminations": assert len(merge_unit_groups) == num_unit_splited @@ -75,9 +75,10 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): with pytest.raises(ValueError): merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) merge_unit_groups = compute_merge_unit_groups( - sorting_analyzer, preset=preset, + sorting_analyzer, + preset=preset, steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], - **job_kwargs + **job_kwargs, ) # DEBUG From 6ee7299c7f7a0b14fbc49f12abf46907b418a072 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 17:40:57 +0200 Subject: [PATCH 34/61] more updates --- src/spikeinterface/core/recording_tools.py | 10 ++++++---- src/spikeinterface/core/tests/test_recording_tools.py | 2 +- src/spikeinterface/preprocessing/silence_periods.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index a4feff4d14..1f46de7d29 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -628,7 +628,7 @@ def get_random_data_chunks( Returns ------- - chunk_list : np.array + chunk_list : np.array | list of np.array Array of concatenate chunks per segment """ recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) @@ -757,8 +757,9 @@ def get_noise_levels( random_slices_kwargs : dict Options transmited to get_random_recording_slices(), please read documentation from this function for more details. - **job_kwargs: - Job kwargs for parallel computing. + + {} + Returns ------- noise_levels : array @@ -790,7 +791,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, random_slices_kwargs) + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] def append_noise_chunk(res): @@ -812,6 +813,7 @@ def append_noise_chunk(res): return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) def get_chunk_with_margin( diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 07515ef3f0..1fa9ffe124 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -182,7 +182,7 @@ def test_get_random_recording_slices(): def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - chunks = get_random_data_chunks(rec, random_slices_kwargs=dict(num_chunks_per_segment=50, chunk_size=500, seed=0)) + chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) assert chunks.shape == (50000, 1) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 3129acd3f3..85169011d8 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,10 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: - random_chunk_kwargs = random_chunk_kwargs.copy() - random_chunk_kwargs["seed"] = seed + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, random_slices_kwargs=random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), From 7517d06cb99d398eaecd48e3c50e063a8796bf7f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 24 Oct 2024 10:09:47 +0200 Subject: [PATCH 35/61] fix scaling seed --- src/spikeinterface/preprocessing/tests/test_scaling.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() From bac57fe00450a6fe758a67092c0ff0d8f17ebfb5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:39:08 +0000 Subject: [PATCH 36/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/analyzer_extension_core.py | 1 - src/spikeinterface/core/job_tools.py | 3 +- src/spikeinterface/core/recording_tools.py | 51 ++++++++++--------- .../tests/test_analyzer_extension_core.py | 2 +- .../core/tests/test_recording_tools.py | 35 ++++++++----- .../preprocessing/tests/test_silence.py | 2 +- .../preprocessing/tests/test_whiten.py | 2 +- 7 files changed, 54 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 38d7ab247c..1d3501c4d0 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -693,7 +693,6 @@ class ComputeNoiseLevels(AnalyzerExtension): need_job_kwargs = False need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 8f5df37695..27f05bb36b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -184,6 +184,7 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs + def chunk_duration_to_chunk_size(chunk_duration, recording): if isinstance(chunk_duration, float): chunk_size = int(chunk_duration * recording.get_sampling_frequency()) @@ -196,7 +197,7 @@ def chunk_duration_to_chunk_size(chunk_duration, recording): raise ValueError("chunk_duration must ends with s or ms") chunk_size = int(chunk_duration * recording.get_sampling_frequency()) else: - raise ValueError("chunk_duration must be str or float") + raise ValueError("chunk_duration must be str or float") return chunk_size diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 1f46de7d29..2ab74ce51e 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -514,15 +514,15 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned - - -def get_random_recording_slices(recording, - method="full_random", - num_chunks_per_segment=20, - chunk_duration="500ms", - chunk_size=None, - margin_frames=0, - seed=None): +def get_random_recording_slices( + recording, + method="full_random", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None, +): """ Get random slice of a recording across segments. @@ -593,19 +593,14 @@ def get_random_recording_slices(recording, ] else: raise ValueError(f"get_random_recording_slices : wrong method {method}") - + return recording_slices -def get_random_data_chunks( - recording, - return_scaled=False, - concatenated=True, - **random_slices_kwargs -): +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): """ Extract random chunks across segments. - + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list or a concatenated unique array. @@ -698,7 +693,7 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): recording = worker_ctx["recording"] - + one_chunk = recording.get_traces( start_frame=start_frame, end_frame=end_frame, @@ -706,7 +701,6 @@ def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): return_scaled=worker_ctx["return_scaled"], ) - if worker_ctx["method"] == "mad": med = np.median(one_chunk, axis=0, keepdims=True) # hard-coded so that core doesn't depend on scipy @@ -724,12 +718,13 @@ def _noise_level_chunk_init(recording, return_scaled, method): worker_ctx["method"] = method return worker_ctx + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - random_slices_kwargs : dict = {}, + random_slices_kwargs: dict = {}, **kwargs, ) -> np.ndarray: """ @@ -759,7 +754,7 @@ def get_noise_levels( function for more details. {} - + Returns ------- noise_levels : array @@ -774,7 +769,7 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - # This is to keep backward compatibility + # This is to keep backward compatibility # lets keep for a while and remove this maybe in 0.103.0 # chunk_size used to be in the signature and now is ambiguous random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) @@ -794,6 +789,7 @@ def get_noise_levels( recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] + def append_noise_chunk(res): noise_levels_chunks.append(res) @@ -801,8 +797,14 @@ def append_noise_chunk(res): init_func = _noise_level_chunk_init init_args = (recording, return_scaled, method) executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="noise_level", verbose=False, - gather_func=append_noise_chunk, **job_kwargs + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, ) executor.run(all_chunks=recording_slices) noise_levels_chunks = np.stack(noise_levels_chunks) @@ -813,6 +815,7 @@ def append_noise_chunk(res): return noise_levels + get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b04155261b..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -259,7 +259,7 @@ def test_compute_several(create_cache_folder): # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) - #test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 1fa9ffe124..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -167,19 +167,18 @@ def test_write_memory_recording(): for shm in shms: shm.unlink() + def test_get_random_recording_slices(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - rec_slices = get_random_recording_slices(rec, - method="full_random", - num_chunks_per_segment=20, - chunk_duration="500ms", - margin_frames=0, - seed=0) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) assert len(rec_slices) == 40 for seg_ind, start, stop in rec_slices: assert stop - start == 500 assert seg_ind in (0, 1) + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -216,7 +215,9 @@ def test_get_noise_levels(): assert np.all(noise_levels_1 == noise_levels_2) assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -230,13 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, - random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed)) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, - random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed)) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -358,8 +367,8 @@ def test_do_recording_attributes_match(): # test_write_memory_recording() test_get_random_recording_slices() - # test_get_random_data_chunks() + # test_get_random_data_chunks() # test_get_closest_channels() # test_get_noise_levels() - # test_get_noise_levels_output() + # test_get_noise_levels_output() # test_order_channels_by_depth() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6405b6b0c4..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -48,5 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 3444323488..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -49,5 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" test_whiten(cache_folder) From f66ae7fc9c6cf66c5ca35d11dddedfbb2180080d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 25 Oct 2024 08:28:41 +0100 Subject: [PATCH 37/61] Compute covariance matrix in float64. --- src/spikeinterface/preprocessing/whiten.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 195969ff79..91c74c423f 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -124,7 +124,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) traces_dtype = traces.dtype - # if uint --> force int + # if uint --> force float if traces_dtype.kind == "u": traces = traces.astype("float32") @@ -185,6 +185,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) + random_data = random_data.astype(np.float64) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} From b5c260aed812d6fb6202ffdf13d35e28d79ff4e9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 25 Oct 2024 09:16:26 +0100 Subject: [PATCH 38/61] Update docstring. --- src/spikeinterface/preprocessing/whiten.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 91c74c423f..1c81f2ae42 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be whitened. dtype : None or dtype, default: None + Datatype of the output recording (covariance matrix estimation + and whitening are performed in float64. If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -74,7 +76,8 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" + assert int_scale is not None, ("For recording with dtype=int you must set the output dtype to float " + " OR set a int_scale") if W is not None: W = np.asarray(W) From 18cfb2b385d9cf5e18d622097a41631d94a0e9a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:18:56 +0000 Subject: [PATCH 39/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 1c81f2ae42..4e3135c3e9 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -76,8 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ("For recording with dtype=int you must set the output dtype to float " - " OR set a int_scale") + assert int_scale is not None, ( + "For recording with dtype=int you must set the output dtype to float " " OR set a int_scale" + ) if W is not None: W = np.asarray(W) From 98e5db95aa36a415d520cfe758113fc7c5db9bac Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 25 Oct 2024 10:42:58 +0200 Subject: [PATCH 40/61] recording_slices in run_node_pipeline() --- src/spikeinterface/core/job_tools.py | 22 +++++++++--------- src/spikeinterface/core/node_pipeline.py | 7 +++++- src/spikeinterface/core/recording_tools.py | 2 +- .../core/tests/test_node_pipeline.py | 23 +++++++++++++++---- .../sortingcomponents/peak_detection.py | 6 +++++ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 27f05bb36b..7a6172369b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size): def divide_recording_into_chunks(recording, chunk_size): - all_chunks = [] + recording_slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return all_chunks + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return recording_slices def ensure_n_jobs(recording, n_jobs=1): @@ -387,13 +387,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self, all_chunks=None): + def run(self, recording_slices=None): """ Runs the defined jobs. """ - if all_chunks is None: - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + if recording_slices is None: + recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] @@ -402,17 +402,17 @@ def run(self, all_chunks=None): if self.n_jobs == 1: if self.progress_bar: - all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) worker_ctx = self.init_func(*self.init_args) - for segment_index, frame_start, frame_stop in all_chunks: + for segment_index, frame_start, frame_stop in recording_slices: res = self.func(segment_index, frame_start, frame_stop, worker_ctx) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(all_chunks)) + n_jobs = min(self.n_jobs, len(recording_slices)) # parallel with ProcessPoolExecutor( @@ -421,10 +421,10 @@ def run(self, all_chunks=None): mp_context=mp.get_context(self.mp_context), initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: - results = executor.map(function_wrapper, all_chunks) + results = executor.map(function_wrapper, recording_slices) if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(all_chunks)) + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) for res in results: if self.handle_returns: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d90a20902d..8ca4ba7f3a 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,6 +489,7 @@ def run_node_pipeline( names=None, verbose=False, skip_after_n_peaks=None, + recording_slices=None, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -540,6 +541,10 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the entire recording is computed. Returns ------- @@ -578,7 +583,7 @@ def run_node_pipeline( **job_kwargs, ) - processor.run() + processor.run(recording_slices=recording_slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2ab74ce51e..4aabbfd587 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -806,7 +806,7 @@ def append_noise_chunk(res): gather_func=append_noise_chunk, **job_kwargs, ) - executor.run(all_chunks=recording_slices) + executor.run(recording_slices=recording_slices) noise_levels_chunks = np.stack(noise_levels_chunks) noise_levels = np.mean(noise_levels_chunks, axis=0) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index deef2291c6..400a71c424 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording - +from spikeinterface.core.job_tools import divide_recording_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) -def test_skip_after_n_peaks(): - recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) +def test_skip_after_n_peaks_and_recording_slices(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) @@ -211,18 +211,31 @@ def test_skip_after_n_peaks(): node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) nodes = [node0, node1] + # skip skip_after_n_peaks = 30 some_amplitudes = run_node_pipeline( recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks ) - assert some_amplitudes.size >= skip_after_n_peaks assert some_amplitudes.size < spikes.size + # slices : 1 every 4 + recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = recording_slices[::4] + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices + ) + tolerance = 1.2 + assert some_amplitudes.size < (spikes.size // 4) * tolerance + + + + + # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) - test_skip_after_n_peaks() + test_skip_after_n_peaks_and_recording_slices() diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 5b1d33b334..233b16dcf7 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -57,6 +57,7 @@ def detect_peaks( folder=None, names=None, skip_after_n_peaks=None, + recording_slices=None, **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -83,6 +84,10 @@ def detect_peaks( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the entire recording is computed. {method_doc} {job_doc} @@ -135,6 +140,7 @@ def detect_peaks( folder=folder, names=names, skip_after_n_peaks=skip_after_n_peaks, + recording_slices=recording_slices, ) return outs From aaa689fa9174e8576550528224431b9ea3e32759 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:47:02 +0000 Subject: [PATCH 41/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 400a71c424..028eaecf12 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -229,10 +229,6 @@ def test_skip_after_n_peaks_and_recording_slices(): assert some_amplitudes.size < (spikes.size // 4) * tolerance - - - - # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") From f0f7f6c7165b76f07706254597c5e0730691789a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 28 Oct 2024 10:16:27 +0100 Subject: [PATCH 42/61] Update src/spikeinterface/curation/auto_merge.py Co-authored-by: Alessio Buccino --- src/spikeinterface/curation/auto_merge.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ec5e8be20c..73b69426f1 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -154,7 +154,8 @@ def compute_merge_unit_groups( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! - steps_params : A dictionary whose keys are the steps, and keys are steps parameters. + steps_params : dict + A dictionary whose keys are the steps, and keys are steps parameters. force_copy : boolean, default: True When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting already computed extensions. False if you want to overwrite From 0dd48c424e437e9729af16f44101e881ba1d968e Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 10:27:17 +0100 Subject: [PATCH 43/61] Typos and signatures --- src/spikeinterface/curation/auto_merge.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 73b69426f1..6680a70af4 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -19,7 +19,7 @@ from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph -_compute_merge_persets = { +_compute_merge_presets = { "similarity_correlograms": [ "num_spikes", "remove_contaminated", @@ -146,7 +146,7 @@ def compute_merge_unit_groups( resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. compute_needed_extensions : bool, default : True - Should we force the computation of needed extensions? + Should we force the computation of needed extensions, if not already computed? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None @@ -172,9 +172,11 @@ def compute_merge_unit_groups( References ---------- - This function is inspired and built upon similar functions from Lussac [Llobet]_, + This function used to be inspired and built upon similar functions from Lussac [Llobet]_, done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + + However, it has been greatly consolidated and refined depending on the presets. """ import scipy @@ -187,11 +189,11 @@ def compute_merge_unit_groups( # steps has presendance on presets pass elif preset is not None: - if preset not in _compute_merge_persets: - raise ValueError(f"preset must be one of {list(_compute_merge_persets.keys())}") - steps = _compute_merge_persets[preset] + if preset not in _compute_merge_presets: + raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") + steps = _compute_merge_presets[preset] - if force_copy and compute_needed_extensions: + if force_copy: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -357,7 +359,7 @@ def compute_merge_unit_groups( return merge_unit_groups -def auto_merge( +def auto_merge_units( sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs ) -> SortingAnalyzer: """ From a0587f6e04a210fe6bbde62e8b759176c69a47c3 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 10:30:35 +0100 Subject: [PATCH 44/61] Cleaning requiered extensions --- src/spikeinterface/curation/auto_merge.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 6680a70af4..52dffc0378 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -57,10 +57,9 @@ _required_extensions = { "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["templates", "noise_levels", "templates"], + "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], - "knn": ["templates", "spike_locations", "spike_amplitudes"], - "spike_amplitudes": ["templates"], + "knn": ["templates", "spike_locations", "spike_amplitudes"] } From f22a4cc95a690ee5c0d89608a79c93d9207ca2be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:31:21 +0000 Subject: [PATCH 45/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 52dffc0378..dfcd7bbb17 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -59,7 +59,7 @@ "correlogram": ["correlograms"], "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], - "knn": ["templates", "spike_locations", "spike_amplitudes"] + "knn": ["templates", "spike_locations", "spike_amplitudes"], } From 516acc9dda2c55bd5014f3ac4cee4350d3940607 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 12:14:46 +0100 Subject: [PATCH 46/61] Names --- src/spikeinterface/curation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 579e47a553..0302ffe5b7 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import compute_merge_unit_groups, auto_merge, get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge # manual sorting, From d17181f3bd68f602780ad99e1b618aa3f793b8ad Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Oct 2024 13:08:56 +0100 Subject: [PATCH 47/61] Update src/spikeinterface/preprocessing/whiten.py --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 4e3135c3e9..505e8a330a 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -20,7 +20,7 @@ class WhitenRecording(BasePreprocessor): The recording extractor to be whitened. dtype : None or dtype, default: None Datatype of the output recording (covariance matrix estimation - and whitening are performed in float64. + and whitening are performed in float64). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" From 78738ef679ebf8de5c4a16769aa879e51f68cf29 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 14:41:51 +0100 Subject: [PATCH 48/61] WIP --- src/spikeinterface/curation/auto_merge.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index dfcd7bbb17..f7110f131d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -258,7 +258,7 @@ def compute_merge_unit_groups( outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel - elif step == "unit_locations" in steps: + elif step == "unit_locations": location_ext = sorting_analyzer.get_extension("unit_locations") unit_locations = location_ext.get_data()[:, :2] @@ -267,7 +267,7 @@ def compute_merge_unit_groups( outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram - elif step == "correlogram" in steps: + elif step == "correlogram": correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() censor_ms = params["censor_correlograms_ms"] @@ -297,7 +297,7 @@ def compute_merge_unit_groups( outs["win_sizes"] = win_sizes # STEP : check if potential merge with CC also have template similarity - elif step == "template_similarity" in steps: + elif step == "template_similarity": template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity @@ -305,11 +305,11 @@ def compute_merge_unit_groups( outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes - elif step == "knn" in steps: + elif step == "knn": pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times - elif step == "presence_distance" in steps: + elif step == "presence_distance": presence_distance_kwargs = params.copy() presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ @@ -322,7 +322,7 @@ def compute_merge_unit_groups( outs["presence_distances"] = presence_distances # STEP : check if the cross contamination is significant - elif step == "cross_contamination" in steps: + elif step == "cross_contamination": refractory = ( params["censored_period_ms"], params["refractory_period_ms"], @@ -334,7 +334,7 @@ def compute_merge_unit_groups( outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics - elif step == "quality_score" in steps: + elif step == "quality_score": pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, pair_mask, From 71e38e023ab660b28957c44d518477bfabf1782b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 15:17:13 +0100 Subject: [PATCH 49/61] Mix up with default params. Bringing back order --- src/spikeinterface/curation/auto_merge.py | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index f7110f131d..12f7f9eac3 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -66,11 +66,13 @@ _default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "unit_locations": {"max_distance_um": 50}, + "remove_contaminated": {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 150}, "correlogram": { "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, + "censor_correlograms_ms": 0.15, "sigma_smooth_ms": 0.6, "adaptative_window_thresh": 0.5, }, @@ -83,7 +85,9 @@ "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 1.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, } @@ -391,7 +395,7 @@ def get_potential_auto_merge( sigma_smooth_ms: float = 0.6, adaptative_window_thresh: float = 0.5, censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, + firing_contamination_balance: float = 1.5, k_nn: int = 10, knn_kwargs: dict | None = None, presence_distance_kwargs: dict | None = None, @@ -479,7 +483,7 @@ def get_potential_auto_merge( Parameter to detect the window size in correlogram estimation. censor_correlograms_ms : float, default: 0.15 The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 + firing_contamination_balance : float, default: 1.5 Parameter to control the balance between firing rate and contamination in computing unit "quality score". k_nn : int, default 5 The number of neighbors to consider for every spike in the recording. @@ -843,10 +847,10 @@ def check_improve_contaminations_score( f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores - k = firing_contamination_balance - score_1 = f_1 * (1 - (k + 1) * c_1) - score_2 = f_2 * (1 - (k + 1) * c_2) - score_new = f_new * (1 - (k + 1) * c_new) + k = 1 + firing_contamination_balance + score_1 = f_1 * (1 - k * c_1) + score_2 = f_2 * (1 - k * c_2) + score_new = f_new * (1 - k * c_new) if score_new < score_1 or score_new < score_2: # the score is not improved From 10d455cdf6db3038b59f484d7bc12d107cf8c578 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:18:35 +0000 Subject: [PATCH 50/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 12f7f9eac3..085467fe9f 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -66,9 +66,7 @@ _default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 150}, "correlogram": { "corr_diff_thresh": 0.16, @@ -85,9 +83,7 @@ "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 1.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, } From 988723df9212bda349adc40aaa631ddb68f44123 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 15:50:24 +0100 Subject: [PATCH 51/61] Triangular sup excluding self pairs --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 085467fe9f..994cc25d26 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -197,7 +197,7 @@ def compute_merge_unit_groups( sorting_analyzer = sorting_analyzer.copy() n = unit_ids.size - pair_mask = np.triu(np.arange(n)) > 0 + pair_mask = np.triu(np.arange(n), 1) > 0 outs = dict() for step in steps: From 95120e1391a041924879ab4236b1e431f892c020 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Oct 2024 16:17:21 +0100 Subject: [PATCH 52/61] Update src/spikeinterface/curation/auto_merge.py Co-authored-by: Alessio Buccino --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 994cc25d26..8ac1ef0f95 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -185,7 +185,7 @@ def compute_merge_unit_groups( if preset is None and steps is None: raise ValueError("You need to specify a preset or steps for the auto-merge function") elif steps is not None: - # steps has presendance on presets + # steps has precedence on presets pass elif preset is not None: if preset not in _compute_merge_presets: From 21408543fc5589a06977997fb93567287b8cbbda Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 30 Oct 2024 10:52:50 +0100 Subject: [PATCH 53/61] Docs --- src/spikeinterface/curation/auto_merge.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 8ac1ef0f95..af4407b10e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -121,6 +121,9 @@ def compute_merge_unit_groups( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -424,6 +427,9 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- From df3d2dffda836b90a75ac8a68deb859a3b824b24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:54:41 +0000 Subject: [PATCH 54/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index af4407b10e..eeeb5b2098 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -122,8 +122,8 @@ def compute_merge_unit_groups( Q = f(1 - (k + 1)C) IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed - with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to - have a finer control on these values, please precompute the extensions before applying the auto_merge + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -428,8 +428,8 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed - with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to - have a finer control on these values, please precompute the extensions before applying the auto_merge + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- From 22b90945c82e55c15e06c9c92ebd6b752889906a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 30 Oct 2024 16:56:22 +0100 Subject: [PATCH 55/61] avoid copy when not necessary --- src/spikeinterface/curation/auto_merge.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index eeeb5b2098..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -195,7 +195,19 @@ def compute_merge_unit_groups( raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") steps = _compute_merge_presets[preset] - if force_copy: + # check at least one extension is needed + at_least_one_extension_to_compute = False + for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: + raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -205,14 +217,10 @@ def compute_merge_unit_groups( for step in steps: - assert step in _default_step_params, f"{step} is not a valid step" - if step in _required_extensions: for ext in _required_extensions[step]: if sorting_analyzer.has_extension(ext): continue - if not compute_needed_extensions: - raise ValueError(f"{step} requires {ext} extension") # special case for templates if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): From 12538cc646f47b162b73190326d5b541121b2c1a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 31 Oct 2024 17:49:29 +0000 Subject: [PATCH 56/61] Revert to float32. --- src/spikeinterface/preprocessing/whiten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 505e8a330a..b9c106a5a2 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -20,7 +20,7 @@ class WhitenRecording(BasePreprocessor): The recording extractor to be whitened. dtype : None or dtype, default: None Datatype of the output recording (covariance matrix estimation - and whitening are performed in float64). + and whitening are performed in float32). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -189,7 +189,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype(np.float64) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} From 035d61c38dbdb453a6461de424028d1466367bda Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:02:03 +0000 Subject: [PATCH 57/61] Fix string format error. --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index b9c106a5a2..fa33975a68 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -77,7 +77,7 @@ def __init__( if dtype_.kind == "i": assert int_scale is not None, ( - "For recording with dtype=int you must set the output dtype to float " " OR set a int_scale" + "For recording with dtype=int you must set the output dtype to float OR set a int_scale" ) if W is not None: From 7b8d0a2c1c3e006d8a9a46257e0f06e034aa0a76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:02:31 +0000 Subject: [PATCH 58/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index fa33975a68..57400c1199 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -76,9 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, ( - "For recording with dtype=int you must set the output dtype to float OR set a int_scale" - ) + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if W is not None: W = np.asarray(W) From 74ef4eba21ec8bb7d413f5221d899d3f35c8287f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 1 Nov 2024 17:51:15 -0400 Subject: [PATCH 59/61] Remove remaining array.ptp()s --- .../postprocessing/localization_tools.py | 10 ++++++---- .../sortingcomponents/motion/dredge.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 837b983059..685dcad1f0 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -3,6 +3,10 @@ import warnings import numpy as np +from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity +from spikeinterface.core.template_tools import (_get_nbefore, + get_dense_templates_array, + get_template_extremum_channel) try: import numba @@ -12,8 +16,6 @@ HAVE_NUMBA = False -from spikeinterface.core import compute_sparsity, SortingAnalyzer, Templates -from spikeinterface.core.template_tools import get_template_extremum_channel, _get_nbefore, get_dense_templates_array def compute_monopolar_triangulation( @@ -110,7 +112,7 @@ def compute_monopolar_triangulation( # wf is (nsample, nchan) - chann is only nieghboor wf = templates[i, :, :][:, chan_inds] if feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif feature == "peak_voltage": @@ -188,7 +190,7 @@ def compute_center_of_mass( wf = templates[i, :, :] if feature == "ptp": - wf_data = (wf[:, chan_inds]).ptp(axis=0) + wf_data = np.ptp(wf[:, chan_inds], axis=0) elif feature == "mean": wf_data = (wf[:, chan_inds]).mean(axis=0) elif feature == "energy": diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index e2b6b1a2bc..4db6bb1cb2 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -22,21 +22,15 @@ """ +import gc import warnings -from tqdm.auto import trange import numpy as np +from tqdm.auto import trange -import gc - -from .motion_utils import ( - Motion, - get_spatial_windows, - get_window_domains, - scipy_conv1d, - make_2d_motion_histogram, - get_spatial_bin_edges, -) +from .motion_utils import (Motion, get_spatial_bin_edges, get_spatial_windows, + get_window_domains, make_2d_motion_histogram, + scipy_conv1d) # simple class wrapper to be compliant with estimate_motion @@ -979,7 +973,7 @@ def xcorr_windows( if max_disp_um is None: if rigid: - max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + max_disp_um = int(np.ptp(spatial_bin_edges_um) // 4) else: max_disp_um = int(win_scale_um // 4) From 43b085fefe1574de8bc65764ecd0b246408ebed0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 21:53:15 +0000 Subject: [PATCH 60/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/localization_tools.py | 6 +----- src/spikeinterface/sortingcomponents/motion/dredge.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 685dcad1f0..59c12a9923 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -4,9 +4,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity -from spikeinterface.core.template_tools import (_get_nbefore, - get_dense_templates_array, - get_template_extremum_channel) +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel try: import numba @@ -16,8 +14,6 @@ HAVE_NUMBA = False - - def compute_monopolar_triangulation( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index 4db6bb1cb2..bfedd4e1ee 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -28,9 +28,14 @@ import numpy as np from tqdm.auto import trange -from .motion_utils import (Motion, get_spatial_bin_edges, get_spatial_windows, - get_window_domains, make_2d_motion_histogram, - scipy_conv1d) +from .motion_utils import ( + Motion, + get_spatial_bin_edges, + get_spatial_windows, + get_window_domains, + make_2d_motion_histogram, + scipy_conv1d, +) # simple class wrapper to be compliant with estimate_motion From 0e44185c2918a2d7b53cfe55879fde134c478b57 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Nov 2024 11:05:26 +0100 Subject: [PATCH 61/61] Apply suggestions from code review --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 8ca4ba7f3a..53c2445c77 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -544,7 +544,7 @@ def run_node_pipeline( recording_slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). - If None (default), the entire recording is computed. + If None (default), the function iterates over the entire duration of the recording. Returns ------- diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 233b16dcf7..d03744f8f9 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -87,7 +87,7 @@ def detect_peaks( recording_slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). - If None (default), the entire recording is computed. + If None (default), the function iterates over the entire duration of the recording. {method_doc} {job_doc}