diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..1d3501c4d0 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) + def _set_params(self, **noise_level_params): + params = noise_level_params.copy() return params def _select_extension_data(self, unit_ids): @@ -717,6 +718,15 @@ def _run(self, verbose=False): def _get_data(self): return self.data["noise_levels"] + def _handle_backward_compatibility_on_load(self): + # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None) + # now it is handle more explicitly using random_slices_kwargs=dict() + for key in ("num_chunks_per_segment", "chunk_size", "seed"): + if key in self.params: + if "random_slices_kwargs" not in self.params: + self.params["random_slices_kwargs"] = dict() + self.params["random_slices_kwargs"][key] = self.params.pop(key) + register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 5240edcee7..27f05bb36b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs ): @@ -231,18 +247,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment @@ -382,11 +387,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, all_chunks=None): """ Runs the defined jobs. """ - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if all_chunks is None: + all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 77d427bc88..2ab74ce51e 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,8 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -512,33 +514,38 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned -def get_random_data_chunks( +def get_random_recording_slices( recording, - return_scaled=False, + method="full_random", num_chunks_per_segment=20, - chunk_size=10000, - concatenated=True, - seed=0, + chunk_duration="500ms", + chunk_size=None, margin_frames=0, + seed=None, ): """ - Extract random chunks across segments + Get random slice of a recording across segments. - This is used for instance in get_noise_levels() to estimate noise on traces. + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. Parameters ---------- recording : BaseRecording The recording to get random chunks from - return_scaled : bool, default: False - If True, returned chunks are scaled to uV + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -547,42 +554,89 @@ def get_random_data_chunks( ------- chunk_list : np.array Array of concatenate chunks per segment + + """ # TODO: if segment have differents length make another sampling that dependant on the length of the segment # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY # And randomize the number of chunk per segment weighted by segment duration - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") - rng = np.random.default_rng(seed) - chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] + return recording_slices - chunk_list.extend(segment_trace_chunk) + +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. + + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + return_scaled : bool, default: False + If True, returned chunks are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + chunk_list = [] + for segment_index, start_frame, end_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) @@ -637,19 +691,52 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + random_slices_kwargs: dict = {}, + **kwargs, ) -> np.ndarray: """ Estimate noise for each channel using MAD methods. You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. + + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- @@ -662,8 +749,11 @@ def get_noise_levels( The method to use to estimate noise levels force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor - random_chunk_kwargs : dict - Kwargs for get_random_data_chunks + random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + {} Returns ------- @@ -679,19 +769,56 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # This is to keep backward compatibility + # lets keep for a while and remove this maybe in 0.103.0 + # chunk_size used to be in the signature and now is ambiguous + random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) + if len(random_slices_kwargs_) > 0 or "chunk_size" in job_kwargs: + msg = ( + "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" + "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" + "Please read get_random_recording_slices() documentation for more options." + ) + # if the user use both the old and the new behavior then an error is raised + assert len(random_slices_kwargs) == 0, msg + warnings.warn(msg) + random_slices_kwargs = random_slices_kwargs_ + if "chunk_size" in job_kwargs: + random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] + + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + noise_levels_chunks = [] + + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, + ) + executor.run(all_chunks=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) + + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 626899ab6e..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -2,6 +2,8 @@ import shutil +from pathlib import Path + from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import Templates @@ -250,16 +252,17 @@ def test_compute_several(create_cache_folder): if __name__ == "__main__": - - test_ComputeWaveforms(format="memory", sparse=True) - test_ComputeWaveforms(format="memory", sparse=False) - test_ComputeWaveforms(format="binary_folder", sparse=True) - test_ComputeWaveforms(format="binary_folder", sparse=False) - test_ComputeWaveforms(format="zarr", sparse=True) - test_ComputeWaveforms(format="zarr", sparse=False) - test_ComputeRandomSpikes(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeNoiseLevels(format="memory", sparse=False) - - test_get_children_dependencies() - test_delete_on_recompute() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core" + # test_ComputeWaveforms(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="memory", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) + + # test_get_children_dependencies() + # test_delete_on_recompute(cache_folder) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,6 +11,7 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, + get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -167,6 +168,17 @@ def test_write_memory_recording(): shm.unlink() +def test_get_random_recording_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -182,16 +194,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +214,10 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -216,10 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, chunk_size=1_000) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) + print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, chunk_size=1_000) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -333,14 +359,16 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile - - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() - - test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # import tempfile + + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() + + test_get_random_recording_slices() + # test_get_random_data_chunks() + # test_get_closest_channels() + # test_get_noise_levels() + # test_get_noise_levels_output() + # test_order_channels_by_depth() diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..85169011d8 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder)