From 36474a8cea03c9905dd21931af2c9de418d37798 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 2 Sep 2024 17:36:23 +0200 Subject: [PATCH 01/21] Refactor the get_random_data_chunks with an internal function. to allow more methods --- src/spikeinterface/core/job_tools.py | 28 +-- src/spikeinterface/core/recording_tools.py | 163 ++++++++++++++---- .../core/tests/test_recording_tools.py | 16 +- 3 files changed, 154 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a5279247f5..1aa9ac9333 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -187,6 +187,21 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs @@ -234,18 +249,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2c7e75668f..764b6f0c66 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,7 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, ) @@ -509,6 +510,87 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned + + +def get_random_recording_slices(recording, + method="legacy", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None): + """ + Get random slice of a recording across segments. + + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + methid : "legacy" + The method used. + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames + + concatenated : bool, default: True + If True chunk are concatenated along time axis + seed : int, default: 0 + Random seed + margin_frames : int, default: 0 + Margin in number of frames to avoid edge effects + + Returns + ------- + chunk_list : np.array + Array of concatenate chunks per segment + + + """ + # TODO: if segment have differents length make another sampling that dependant on the length of the segment + # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY + # And randomize the number of chunk per segment weighted by segment duration + + if method == "legacy": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") + + return recording_slices + + def get_random_data_chunks( recording, return_scaled=False, @@ -545,41 +627,56 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - # TODO: if segment have differents length make another sampling that dependant on the length of the segment - # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY - # And randomize the number of chunk per segment weighted by segment duration - - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + # # check chunk size + # num_segments = recording.get_num_segments() + # for segment_index in range(num_segments): + # chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + # if chunk_size > chunk_size_limit: + # chunk_size = chunk_size_limit - 1 + # warnings.warn( + # f"chunk_size is greater than the number " + # f"of samples for segment index {segment_index}. " + # f"Using {chunk_size}." + # ) + + # rng = np.random.default_rng(seed) + # chunk_list = [] + # low = margin_frames + # size = num_chunks_per_segment + # for segment_index in range(num_segments): + # num_frames = recording.get_num_frames(segment_index) + # high = num_frames - chunk_size - margin_frames + # random_starts = rng.integers(low=low, high=high, size=size) + # segment_trace_chunk = [ + # recording.get_traces( + # start_frame=start_frame, + # end_frame=(start_frame + chunk_size), + # segment_index=segment_index, + # return_scaled=return_scaled, + # ) + # for start_frame in random_starts + # ] + + # chunk_list.extend(segment_trace_chunk) + + recording_slices = get_random_recording_slices(recording, + method="legacy", + num_chunks_per_segment=num_chunks_per_segment, + chunk_size=chunk_size, + # chunk_duration=chunk_duration, + margin_frames=margin_frames, + seed=seed) + print(recording_slices) - rng = np.random.default_rng(seed) chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] - - chunk_list.extend(segment_trace_chunk) + for segment_index, start_frame, stop_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=(start_frame + chunk_size), + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..e54981744d 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -333,14 +333,14 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile + # import tempfile - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # test_get_closest_channels() + # test_get_noise_levels() + # test_order_channels_by_depth() From 63574ef6a45b948f228bae94dbf090a6486279e7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 3 Sep 2024 17:08:10 +0200 Subject: [PATCH 02/21] Noise level in parallel --- src/spikeinterface/core/job_tools.py | 6 +- src/spikeinterface/core/recording_tools.py | 71 ++++++++++++++++--- .../core/tests/test_recording_tools.py | 21 +++--- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 1aa9ac9333..45d04e83df 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -389,11 +389,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, all_chunks=None): """ Runs the defined jobs. """ - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if all_chunks is None: + all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 764b6f0c66..37fcd9714a 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -19,6 +19,7 @@ ChunkRecordingExecutor, _shared_job_kwargs_doc, chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -666,7 +667,6 @@ def get_random_data_chunks( # chunk_duration=chunk_duration, margin_frames=margin_frames, seed=seed) - print(recording_slices) chunk_list = [] for segment_index, start_frame, stop_frame in recording_slices: @@ -731,12 +731,42 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + **kwargs, + # **random_chunk_kwargs, + # **job_kwargs ): """ Estimate noise for each channel using MAD methods. @@ -773,19 +803,40 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) + + # if method == "mad": + # med = np.median(random_chunks, axis=0, keepdims=True) + # # hard-coded so that core doesn't depend on scipy + # noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 + # elif method == "std": + # noise_levels = np.std(random_chunks, axis=0) + + random_slices_kwargs, job_kwargs = split_job_kwargs(kwargs) + recording_slices = get_random_recording_slices(recording,**random_slices_kwargs) + + noise_levels_chunks = [] + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name="noise_level", verbose=False, + gather_func=append_noise_chunk, **job_kwargs + ) + executor.run(all_chunks=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index e54981744d..918e15803a 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -166,6 +166,9 @@ def test_write_memory_recording(): for shm in shms: shm.unlink() +def test_get_random_recording_slices(): + # TODO + pass def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) @@ -182,16 +185,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +205,8 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) def test_get_noise_levels_output(): @@ -340,7 +344,8 @@ def test_do_recording_attributes_match(): # test_write_binary_recording(tmp_path) # test_write_memory_recording() - test_get_random_data_chunks() + # test_get_random_recording_slices() + # test_get_random_data_chunks() # test_get_closest_channels() - # test_get_noise_levels() + test_get_noise_levels() # test_order_channels_by_depth() From 4e38686836aa0a105cb01e8c9dcd25bf7f20a662 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 7 Oct 2024 20:51:52 +0200 Subject: [PATCH 03/21] feedback and clean --- src/spikeinterface/core/recording_tools.py | 75 ++++++---------------- 1 file changed, 20 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 37fcd9714a..cde6f0ced5 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -529,18 +529,19 @@ def get_random_recording_slices(recording, ---------- recording : BaseRecording The recording to get random chunks from - methid : "legacy" - The method used. + method : "legacy" + The method used to get random slices. + * "legacy" : the one used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment chunk_duration : str | float | None, default "500ms" The duration of each chunk in 's' or 'ms' chunk_size : int | None - Size of a chunk in number of frames - + Size of a chunk in number of frames. This is ued only if chunk_duration is None. concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -596,7 +597,8 @@ def get_random_data_chunks( recording, return_scaled=False, num_chunks_per_segment=20, - chunk_size=10000, + chunk_duration="500ms", + chunk_size=None, concatenated=True, seed=0, margin_frames=0, @@ -604,8 +606,6 @@ def get_random_data_chunks( """ Extract random chunks across segments - This is used for instance in get_noise_levels() to estimate noise on traces. - Parameters ---------- recording : BaseRecording @@ -614,8 +614,10 @@ def get_random_data_chunks( If True, returned chunks are scaled to uV num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. concatenated : bool, default: True If True chunk are concatenated along time axis seed : int, default: 0 @@ -628,51 +630,19 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - # # check chunk size - # num_segments = recording.get_num_segments() - # for segment_index in range(num_segments): - # chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - # if chunk_size > chunk_size_limit: - # chunk_size = chunk_size_limit - 1 - # warnings.warn( - # f"chunk_size is greater than the number " - # f"of samples for segment index {segment_index}. " - # f"Using {chunk_size}." - # ) - - # rng = np.random.default_rng(seed) - # chunk_list = [] - # low = margin_frames - # size = num_chunks_per_segment - # for segment_index in range(num_segments): - # num_frames = recording.get_num_frames(segment_index) - # high = num_frames - chunk_size - margin_frames - # random_starts = rng.integers(low=low, high=high, size=size) - # segment_trace_chunk = [ - # recording.get_traces( - # start_frame=start_frame, - # end_frame=(start_frame + chunk_size), - # segment_index=segment_index, - # return_scaled=return_scaled, - # ) - # for start_frame in random_starts - # ] - - # chunk_list.extend(segment_trace_chunk) - recording_slices = get_random_recording_slices(recording, method="legacy", num_chunks_per_segment=num_chunks_per_segment, + chunk_duration=chunk_duration, chunk_size=chunk_size, - # chunk_duration=chunk_duration, margin_frames=margin_frames, seed=seed) chunk_list = [] - for segment_index, start_frame, stop_frame in recording_slices: + for segment_index, start_frame, end_frame in recording_slices: traces_chunk = recording.get_traces( start_frame=start_frame, - end_frame=(start_frame + chunk_size), + end_frame=end_frame, segment_index=segment_index, return_scaled=return_scaled, ) @@ -773,7 +743,11 @@ def get_noise_levels( You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it use MAD estimator (more robust than STD) ot the STD on each chunk. + Finally the average on all MAD is performed. + + The result is cached in a property of the recording. + Next call on the same recording will use the cache unless force_recompute=True. Parameters ---------- @@ -803,15 +777,6 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - # random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - # if method == "mad": - # med = np.median(random_chunks, axis=0, keepdims=True) - # # hard-coded so that core doesn't depend on scipy - # noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - # elif method == "std": - # noise_levels = np.std(random_chunks, axis=0) - random_slices_kwargs, job_kwargs = split_job_kwargs(kwargs) recording_slices = get_random_recording_slices(recording,**random_slices_kwargs) From 49c7a92a57af5a65f7367b375567afaed6abda56 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 15 Oct 2024 13:46:36 +0200 Subject: [PATCH 04/21] Use existing sparsity for unit location + add location with max channel --- .../postprocessing/localization_tools.py | 67 +++++++++++++++++-- .../tests/test_unit_locations.py | 1 + 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59ca8cf7db 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,8 +76,12 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() + + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity - sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -157,9 +161,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -650,8 +658,59 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) + +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. + + This use inetrnally get_template_extremum_channel() + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: str | int | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_location: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + mode=mode, + outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_location + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): From 9cf9377a30b1733223037c58bb05709f0e76d5c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:50:21 +0000 Subject: [PATCH 05/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/localization_tools.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 59ca8cf7db..4bf39e00e8 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,7 +76,7 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() - + if sorting_analyzer_or_templates.sparsity is None: sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) else: @@ -167,7 +167,7 @@ def compute_center_of_mass( ) else: sparsity = sorting_analyzer_or_templates.sparsity - + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -658,7 +658,6 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) - def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, @@ -691,10 +690,7 @@ def compute_location_max_channel( 2d """ extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, - peak_sign=peak_sign, - mode=mode, - outputs="index" + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" ) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: From d8ee9da3dbc5599dc2876b47933aab8b352fab71 Mon Sep 17 00:00:00 2001 From: rainsong <57996958+522848942@users.noreply.github.com> Date: Mon, 21 Oct 2024 00:06:19 +0800 Subject: [PATCH 06/21] Update core.rst doc error --- doc/modules/core.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 8aa1815a55..5df9a7e6b1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -385,7 +385,7 @@ and merging unit groups. sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3]) sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0]) - sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3]) + sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]]) All computed extensions will be automatically propagated or merged when curating. Please refer to the :ref:`modules/curation:Curation module` documentation for more information. From 0e5f50fdfb6b6ae35a9b06d814e811ac9bf833ee Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 13:54:42 +0200 Subject: [PATCH 07/21] merci zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/postprocessing/localization_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 4bf39e00e8..a17abea1eb 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,14 +667,14 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This use inetrnally get_template_extremum_channel() + This uses interrnally `get_template_extremum_channel()` Parameters ---------- templates_or_sorting_analyzer : SortingAnalyzer | Templates A SortingAnalyzer or Templates object - unit_ids: str | int | None + unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels From 1411c6fdf9d89c68a205c1512b1dce9ce2a29c62 Mon Sep 17 00:00:00 2001 From: OlivierPeron <79974181+OlivierPeron@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:59:20 +0200 Subject: [PATCH 08/21] Loading templates Loading templates whatever the operator --- src/spikeinterface/core/template_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 934b18ed49..769610ad2b 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,7 +31,8 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] + templates_array = ext.data.get("average") or ext.data.get("median") + assert templates_array is not None, "Average or median templates have not been computed." else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: From 72357a68a25acc58c6634300d574e056a1e46857 Mon Sep 17 00:00:00 2001 From: OlivierPeron <79974181+OlivierPeron@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:22:52 +0200 Subject: [PATCH 09/21] Update src/spikeinterface/core/template_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/template_tools.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 769610ad2b..3c8663df70 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,8 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data.get("average") or ext.data.get("median") - assert templates_array is not None, "Average or median templates have not been computed." + if "average" in ext.data: + templates_array = ext.data.get("average") + elif "median" in ext.data: + templates_array = ext.data.get("median") + else: + raise ValueError("Average or median templates have not been computed.") else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: From f4dd922447eac6d422e43093a5b93a8877df34be Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 12:44:26 -0400 Subject: [PATCH 10/21] better error message (#3479) --- src/spikeinterface/core/baserecordingsnippets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 310533c96b..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( - f"The given Probe have 'device_channel_indices' that do not match channel count \n" - f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n" + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) From 0be00cf32fad46b1a55d7018ea051014644568ab Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:14 +0200 Subject: [PATCH 11/21] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a17abea1eb..3372a34c98 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,7 +667,7 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This uses interrnally `get_template_extremum_channel()` + This uses internally `get_template_extremum_channel()` Parameters From 0002edcbe99764fcf65edae405a2be59647691f1 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:23 +0200 Subject: [PATCH 12/21] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 3372a34c98..67d469f85c 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -686,7 +686,7 @@ def compute_location_max_channel( Returns ------- - unit_location: np.ndarray + unit_locations: np.ndarray 2d """ extremum_channels_index = get_template_extremum_channel( From b4e681d8524d119a86ba21fd9a19c5e6716c1ca0 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:33 +0200 Subject: [PATCH 13/21] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 67d469f85c..a073b6c518 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -697,11 +697,11 @@ def compute_location_max_channel( unit_ids = templates_or_sorting_analyzer.unit_ids else: unit_ids = np.asarray(unit_ids) - unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") for i, unit_id in enumerate(unit_ids): - unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] - return unit_location + return unit_locations _unit_location_methods = { From 55b50abdebf5f1e0aa0c00307d706588b1d9038d Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:54 +0200 Subject: [PATCH 14/21] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a073b6c518..837b983059 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -684,7 +684,7 @@ def compute_location_max_channel( * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - Returns + Returns ------- unit_locations: np.ndarray 2d From c2f980c6bb40df45d411904dd5f58288fa7b8ad3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:28:20 -0400 Subject: [PATCH 15/21] typos --- doc/modules/curation.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d115b33e4a..d24fc810b0 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes redundant units from the sorting output. Redundant units are units that share over a certain percentage of spikes, by default 80%. -The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. +The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. .. code-block:: python @@ -102,13 +102,16 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. ) # remove redundant units from SortingAnalyzer object - clean_sorting_analyzer = remove_redundant_units( + # note this returns a cleaned sorting + clean_sorting = remove_redundant_units( sorting_analyzer, duplicate_threshold=0.9, remove_strategy="min_shift" ) + # in order to have a sorter with only the non-redundant units do: + clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) -We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps +We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps the unit (among the redundant ones), with a better template alignment. From 61ce0007476b9a1e80d69bdbbc42e70d8bcce626 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:32:37 -0400 Subject: [PATCH 16/21] better comment --- doc/modules/curation.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d24fc810b0..37de992806 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -108,7 +108,9 @@ The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. duplicate_threshold=0.9, remove_strategy="min_shift" ) - # in order to have a sorter with only the non-redundant units do: + # in order to have a SortingAnalyer with only the non-redundant units one must + # select the designed units remembering to give format and folder if one wants + # a persistent SortingAnalyzer. clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps From 1ce2f8dc411acadf7bdf6bdd537f616d448a536c Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 23 Oct 2024 14:58:41 +0200 Subject: [PATCH 17/21] merci alessio Co-authored-by: Alessio Buccino --- src/spikeinterface/core/recording_tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 01f541ddac..790511ad88 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -744,11 +744,11 @@ def get_noise_levels( You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) ot the STD on each chunk. - Finally the average on all MAD is performed. + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. - The result is cached in a property of the recording. - Next call on the same recording will use the cache unless force_recompute=True. + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- From eb6219999ce1d8d1c898575ec9a71045c6a8bcb1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 15:50:51 +0200 Subject: [PATCH 18/21] wip --- src/spikeinterface/core/recording_tools.py | 27 +++++++------------ .../preprocessing/silence_periods.py | 4 ++- .../preprocessing/tests/test_silence.py | 5 +++- .../preprocessing/tests/test_whiten.py | 7 +++-- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 790511ad88..a4feff4d14 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -601,11 +601,16 @@ def get_random_data_chunks( recording, return_scaled=False, concatenated=True, - random_slices_kwargs={}, - **kwargs, + **random_slices_kwargs ): """ - Extract random chunks across segments + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. + Parameters ---------- @@ -617,7 +622,7 @@ def get_random_data_chunks( Number of chunks per segment concatenated : bool, default: True If True chunk are concatenated along time axis - random_slices_kwargs : dict + **random_slices_kwargs : dict Options transmited to get_random_recording_slices(), please read documentation from this function for more details. @@ -626,18 +631,6 @@ def get_random_data_chunks( chunk_list : np.array Array of concatenate chunks per segment """ - if len(kwargs) > 0: - # This is to keep backward compatibility - # lets keep for a while and remove this maybe in 0.103.0 - msg = ( - "get_random_data_chunks(recording, num_chunks_per_segment=20) is deprecated\n" - "Now, you need to use get_random_data_chunks(recording, random_slices_kwargs=dict(num_chunks_per_segment=20))\n" - "Please read get_random_recording_slices() documentation for more options." - ) - assert len(random_slices_kwargs) ==0, msg - warnings.warn(msg) - random_slices_kwargs = kwargs - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) chunk_list = [] @@ -797,7 +790,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + recording_slices = get_random_recording_slices(recording, random_slices_kwargs) noise_levels_chunks = [] def append_noise_chunk(res): diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..3129acd3f3 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_chunk_kwargs = random_chunk_kwargs.copy() + random_chunk_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_chunk_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..6405b6b0c4 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..3444323488 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder) From 6ee7299c7f7a0b14fbc49f12abf46907b418a072 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 17:40:57 +0200 Subject: [PATCH 19/21] more updates --- src/spikeinterface/core/recording_tools.py | 10 ++++++---- src/spikeinterface/core/tests/test_recording_tools.py | 2 +- src/spikeinterface/preprocessing/silence_periods.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index a4feff4d14..1f46de7d29 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -628,7 +628,7 @@ def get_random_data_chunks( Returns ------- - chunk_list : np.array + chunk_list : np.array | list of np.array Array of concatenate chunks per segment """ recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) @@ -757,8 +757,9 @@ def get_noise_levels( random_slices_kwargs : dict Options transmited to get_random_recording_slices(), please read documentation from this function for more details. - **job_kwargs: - Job kwargs for parallel computing. + + {} + Returns ------- noise_levels : array @@ -790,7 +791,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, random_slices_kwargs) + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] def append_noise_chunk(res): @@ -812,6 +813,7 @@ def append_noise_chunk(res): return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) def get_chunk_with_margin( diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 07515ef3f0..1fa9ffe124 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -182,7 +182,7 @@ def test_get_random_recording_slices(): def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - chunks = get_random_data_chunks(rec, random_slices_kwargs=dict(num_chunks_per_segment=50, chunk_size=500, seed=0)) + chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) assert chunks.shape == (50000, 1) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 3129acd3f3..85169011d8 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,10 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: - random_chunk_kwargs = random_chunk_kwargs.copy() - random_chunk_kwargs["seed"] = seed + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, random_slices_kwargs=random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), From 7517d06cb99d398eaecd48e3c50e063a8796bf7f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 24 Oct 2024 10:09:47 +0200 Subject: [PATCH 20/21] fix scaling seed --- src/spikeinterface/preprocessing/tests/test_scaling.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() From bac57fe00450a6fe758a67092c0ff0d8f17ebfb5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:39:08 +0000 Subject: [PATCH 21/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/analyzer_extension_core.py | 1 - src/spikeinterface/core/job_tools.py | 3 +- src/spikeinterface/core/recording_tools.py | 51 ++++++++++--------- .../tests/test_analyzer_extension_core.py | 2 +- .../core/tests/test_recording_tools.py | 35 ++++++++----- .../preprocessing/tests/test_silence.py | 2 +- .../preprocessing/tests/test_whiten.py | 2 +- 7 files changed, 54 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 38d7ab247c..1d3501c4d0 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -693,7 +693,6 @@ class ComputeNoiseLevels(AnalyzerExtension): need_job_kwargs = False need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 8f5df37695..27f05bb36b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -184,6 +184,7 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs + def chunk_duration_to_chunk_size(chunk_duration, recording): if isinstance(chunk_duration, float): chunk_size = int(chunk_duration * recording.get_sampling_frequency()) @@ -196,7 +197,7 @@ def chunk_duration_to_chunk_size(chunk_duration, recording): raise ValueError("chunk_duration must ends with s or ms") chunk_size = int(chunk_duration * recording.get_sampling_frequency()) else: - raise ValueError("chunk_duration must be str or float") + raise ValueError("chunk_duration must be str or float") return chunk_size diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 1f46de7d29..2ab74ce51e 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -514,15 +514,15 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned - - -def get_random_recording_slices(recording, - method="full_random", - num_chunks_per_segment=20, - chunk_duration="500ms", - chunk_size=None, - margin_frames=0, - seed=None): +def get_random_recording_slices( + recording, + method="full_random", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None, +): """ Get random slice of a recording across segments. @@ -593,19 +593,14 @@ def get_random_recording_slices(recording, ] else: raise ValueError(f"get_random_recording_slices : wrong method {method}") - + return recording_slices -def get_random_data_chunks( - recording, - return_scaled=False, - concatenated=True, - **random_slices_kwargs -): +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): """ Extract random chunks across segments. - + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list or a concatenated unique array. @@ -698,7 +693,7 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): recording = worker_ctx["recording"] - + one_chunk = recording.get_traces( start_frame=start_frame, end_frame=end_frame, @@ -706,7 +701,6 @@ def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): return_scaled=worker_ctx["return_scaled"], ) - if worker_ctx["method"] == "mad": med = np.median(one_chunk, axis=0, keepdims=True) # hard-coded so that core doesn't depend on scipy @@ -724,12 +718,13 @@ def _noise_level_chunk_init(recording, return_scaled, method): worker_ctx["method"] = method return worker_ctx + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - random_slices_kwargs : dict = {}, + random_slices_kwargs: dict = {}, **kwargs, ) -> np.ndarray: """ @@ -759,7 +754,7 @@ def get_noise_levels( function for more details. {} - + Returns ------- noise_levels : array @@ -774,7 +769,7 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - # This is to keep backward compatibility + # This is to keep backward compatibility # lets keep for a while and remove this maybe in 0.103.0 # chunk_size used to be in the signature and now is ambiguous random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) @@ -794,6 +789,7 @@ def get_noise_levels( recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] + def append_noise_chunk(res): noise_levels_chunks.append(res) @@ -801,8 +797,14 @@ def append_noise_chunk(res): init_func = _noise_level_chunk_init init_args = (recording, return_scaled, method) executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="noise_level", verbose=False, - gather_func=append_noise_chunk, **job_kwargs + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, ) executor.run(all_chunks=recording_slices) noise_levels_chunks = np.stack(noise_levels_chunks) @@ -813,6 +815,7 @@ def append_noise_chunk(res): return noise_levels + get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b04155261b..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -259,7 +259,7 @@ def test_compute_several(create_cache_folder): # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) - #test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 1fa9ffe124..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -167,19 +167,18 @@ def test_write_memory_recording(): for shm in shms: shm.unlink() + def test_get_random_recording_slices(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - rec_slices = get_random_recording_slices(rec, - method="full_random", - num_chunks_per_segment=20, - chunk_duration="500ms", - margin_frames=0, - seed=0) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) assert len(rec_slices) == 40 for seg_ind, start, stop in rec_slices: assert stop - start == 500 assert seg_ind in (0, 1) + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -216,7 +215,9 @@ def test_get_noise_levels(): assert np.all(noise_levels_1 == noise_levels_2) assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -230,13 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, - random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed)) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, - random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed)) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -358,8 +367,8 @@ def test_do_recording_attributes_match(): # test_write_memory_recording() test_get_random_recording_slices() - # test_get_random_data_chunks() + # test_get_random_data_chunks() # test_get_closest_channels() # test_get_noise_levels() - # test_get_noise_levels_output() + # test_get_noise_levels_output() # test_order_channels_by_depth() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6405b6b0c4..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -48,5 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 3444323488..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -49,5 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" test_whiten(cache_folder)