From d0263c043039900167d4d19d0789b20adaeb52f0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 27 Nov 2023 10:36:47 +0100 Subject: [PATCH 1/5] Fix memory leak in lsmr solver and optimize correct_motion --- src/spikeinterface/preprocessing/motion.py | 66 ++++++++++--------- .../sortingcomponents/motion_estimation.py | 3 + 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 3992a4c8c6..e26ae6dbc3 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -8,6 +8,7 @@ from spikeinterface.core import get_noise_levels, fix_job_kwargs from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.core_tools import SIJsonEncoder +from torch import gather motion_options_preset = { # This preset should be the most acccurate @@ -20,7 +21,7 @@ exclude_sweep_ms=0.1, radius_um=50, ), - "select_kwargs": None, + "select_kwargs": dict(), "localize_peaks_kwargs": dict( method="monopolar_triangulation", radius_um=75.0, @@ -83,7 +84,7 @@ exclude_sweep_ms=0.1, radius_um=50, ), - "select_kwargs": None, + "select_kwargs": dict(), "localize_peaks_kwargs": dict( method="center_of_mass", radius_um=75.0, @@ -111,7 +112,7 @@ exclude_sweep_ms=0.1, radius_um=50, ), - "select_kwargs": None, + "select_kwargs": dict(), "localize_peaks_kwargs": dict( method="grid_convolution", radius_um=40.0, @@ -157,7 +158,7 @@ def correct_motion( folder=None, output_motion_info=False, detect_kwargs={}, - select_kwargs=None, + select_kwargs={}, localize_peaks_kwargs={}, estimate_motion_kwargs={}, interpolate_motion_kwargs={}, @@ -241,13 +242,22 @@ def correct_motion( # get preset params and update if necessary params = motion_options_preset[preset] detect_kwargs = dict(params["detect_kwargs"], **detect_kwargs) - if params["select_kwargs"] is None: - select_kwargs = None - else: - select_kwargs = dict(params["select_kwargs"], **select_kwargs) + select_kwargs = dict(params["select_kwargs"], **select_kwargs) localize_peaks_kwargs = dict(params["localize_peaks_kwargs"], **localize_peaks_kwargs) estimate_motion_kwargs = dict(params["estimate_motion_kwargs"], **estimate_motion_kwargs) interpolate_motion_kwargs = dict(params["interpolate_motion_kwargs"], **interpolate_motion_kwargs) + do_selection = len(select_kwargs) > 0 + + # params + parameters = dict( + detect_kwargs=detect_kwargs, + select_kwargs=select_kwargs, + localize_peaks_kwargs=localize_peaks_kwargs, + estimate_motion_kwargs=estimate_motion_kwargs, + interpolate_motion_kwargs=interpolate_motion_kwargs, + job_kwargs=job_kwargs, + sampling_frequency=recording.sampling_frequency, + ) if output_motion_info: motion_info = {} @@ -255,13 +265,20 @@ def correct_motion( motion_info = None job_kwargs = fix_job_kwargs(job_kwargs) - noise_levels = get_noise_levels(recording, return_scaled=False) - if select_kwargs is None: - # maybe do this directly in the folder when not None + if folder is not None: + folder = Path(folder) + folder.mkdir(exist_ok=True, parents=True) + + (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") + if recording.check_serializability("json"): + recording.dump_to_json(folder / "recording.json") + gather_mode = "npy" + else: gather_mode = "memory" + if not do_selection: # node detect method = detect_kwargs.pop("method", "locally_exclusive") method_class = detect_peak_methods[method] @@ -281,9 +298,10 @@ def correct_motion( job_kwargs, job_name="detect and localize", gather_mode=gather_mode, + gather_kwargs={"exist_ok": True}, squeeze_output=False, - folder=None, - names=None, + folder=folder, + names=["peaks", "peak_locations"], ) t1 = time.perf_counter() run_times = dict( @@ -307,6 +325,9 @@ def correct_motion( select_peaks=t2 - t1, localize_peaks=t3 - t2, ) + if folder is not None: + np.save(folder / "peaks.npy", peaks) + np.save(folder / "peak_locations.npy", peak_locations) t0 = time.perf_counter() motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) @@ -318,29 +339,10 @@ def correct_motion( ) if folder is not None: - folder = Path(folder) - folder.mkdir(exist_ok=True, parents=True) - - # params and run times - parameters = dict( - detect_kwargs=detect_kwargs, - select_kwargs=select_kwargs, - localize_peaks_kwargs=localize_peaks_kwargs, - estimate_motion_kwargs=estimate_motion_kwargs, - interpolate_motion_kwargs=interpolate_motion_kwargs, - job_kwargs=job_kwargs, - sampling_frequency=recording.sampling_frequency, - ) - (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_serializability("json"): - recording.dump_to_json(folder / "recording.json") - np.save(folder / "peaks.npy", peaks) - np.save(folder / "peak_locations.npy", peak_locations) np.save(folder / "temporal_bins.npy", temporal_bins) np.save(folder / "motion.npy", motion) - np.save(folder / "peak_locations.npy", peak_locations) if spatial_bins is not None: np.save(folder / "spatial_bins.npy", spatial_bins) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index df73575a01..141fc531f4 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1039,6 +1039,7 @@ def jac(p): displacement = p elif convergence_method == "lsmr": + import gc from scipy import sparse from scipy.stats import zscore @@ -1170,6 +1171,8 @@ def jac(p): # warm start next iteration p0 = displacement + # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + gc.collect() displacement = displacement.reshape(B, T).T else: From 5f127c40bb6284c67179ae349294dae929ef574c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 27 Nov 2023 11:32:28 +0100 Subject: [PATCH 2/5] Remove unused import --- src/spikeinterface/preprocessing/motion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e26ae6dbc3..f451ef8618 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -8,7 +8,6 @@ from spikeinterface.core import get_noise_levels, fix_job_kwargs from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.core_tools import SIJsonEncoder -from torch import gather motion_options_preset = { # This preset should be the most acccurate From 54bb7fdcab895fddd25a5f59682e5974717f33e9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 27 Nov 2023 13:21:07 +0100 Subject: [PATCH 3/5] Add TODO --- src/spikeinterface/sortingcomponents/motion_estimation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 141fc531f4..1345bd312c 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1172,6 +1172,7 @@ def jac(p): # warm start next iteration p0 = displacement # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + # TODO: check if this gets fixed in scipy gc.collect() displacement = displacement.reshape(B, T).T From 0fd1e67f722918026e35f62bed292e3253114641 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 30 Nov 2023 16:27:30 +0100 Subject: [PATCH 4/5] Always use gather_mode='memory' and then save --- src/spikeinterface/preprocessing/motion.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index f451ef8618..f2c7983f2b 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -273,9 +273,6 @@ def correct_motion( (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") if recording.check_serializability("json"): recording.dump_to_json(folder / "recording.json") - gather_mode = "npy" - else: - gather_mode = "memory" if not do_selection: # node detect @@ -296,11 +293,11 @@ def correct_motion( pipeline_nodes, job_kwargs, job_name="detect and localize", - gather_mode=gather_mode, + gather_mode="memory", gather_kwargs={"exist_ok": True}, squeeze_output=False, - folder=folder, - names=["peaks", "peak_locations"], + folder=None, + names=None, ) t1 = time.perf_counter() run_times = dict( @@ -324,9 +321,9 @@ def correct_motion( select_peaks=t2 - t1, localize_peaks=t3 - t2, ) - if folder is not None: - np.save(folder / "peaks.npy", peaks) - np.save(folder / "peak_locations.npy", peak_locations) + if folder is not None: + np.save(folder / "peaks.npy", peaks) + np.save(folder / "peak_locations.npy", peak_locations) t0 = time.perf_counter() motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) From cdc8d58492725e6dcd5c3d624ac5c59902e45b78 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 1 Dec 2023 16:43:21 +0100 Subject: [PATCH 5/5] Sam's suggestions --- src/spikeinterface/preprocessing/motion.py | 6 ++++-- src/spikeinterface/sortingcomponents/motion_estimation.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index f2c7983f2b..c81630fc1b 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -275,6 +275,8 @@ def correct_motion( recording.dump_to_json(folder / "recording.json") if not do_selection: + # maybe do this directly in the folder when not None, but might be slow on external storage + gather_mode = "memory" # node detect method = detect_kwargs.pop("method", "locally_exclusive") method_class = detect_peak_methods[method] @@ -293,8 +295,8 @@ def correct_motion( pipeline_nodes, job_kwargs, job_name="detect and localize", - gather_mode="memory", - gather_kwargs={"exist_ok": True}, + gather_mode=gather_mode, + gather_kwargs=None, squeeze_output=False, folder=None, names=None, diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 1345bd312c..8eb9cafe9d 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -220,7 +220,7 @@ class DecentralizedRegistration: pairwise_displacement_method: "conv" or "phase_cross_correlation" How to estimate the displacement in the pairwise matrix. max_displacement_um: float - Maximum possible discplacement in micrometers. + Maximum possible displacement in micrometers. weight_scale: "linear" or "exp" For parwaise displacement, how to to rescale the associated weight matrix. error_sigma: float, default: 0.2