From 18b7cfebfcd52c08513a521f96da6975bb0d89aa Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 11 Jul 2024 08:53:59 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 doc/how_to/drift_with_lfp.rst                 | 14 +-
 doc/how_to/index.rst                          |  2 +-
 examples/how_to/drift_with_lfp.py             |  3 -
 .../plot_1_estimate_motion.py                 | 18 +-
 src/spikeinterface/preprocessing/motion.py    | 14 +-
 .../sortingcomponents/motion/__init__.py      |  8 +-
 .../sortingcomponents/motion/decentralized.py | 14 +-
 .../sortingcomponents/motion/dredge.py        | 222 +++++++-----------
 .../motion/iterative_template.py              |  3 -
 .../motion/motion_cleaner.py                  |  3 +-
 .../motion/motion_estimation.py               | 12 +-
 .../sortingcomponents/motion/motion_utils.py  | 28 +--
 .../motion/tests/test_drege.py                |  2 -
 .../motion/tests/test_motion_estimation.py    |  3 +-
 14 files changed, 134 insertions(+), 212 deletions(-)

diff --git a/doc/how_to/drift_with_lfp.rst b/doc/how_to/drift_with_lfp.rst
index e8d48301a0..a215f0920f 100644
--- a/doc/how_to/drift_with_lfp.rst
+++ b/doc/how_to/drift_with_lfp.rst
@@ -36,7 +36,7 @@ For each patient, the dataset contains two recordings: a high-pass (AP -
 
     from pathlib import Path
     import matplotlib.pyplot as plt
-    
+
     import spikeinterface.full as si
     from spikeinterface.sortingcomponents.motion import estimate_motion
@@ -57,7 +57,7 @@ read the spikeglx file
 
 .. parsed-literal::
 
-    SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples 
+    SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples
                                 873.32s (14.56 minutes) - int16 dtype - 1.56 GiB
 
@@ -87,7 +87,7 @@ eyes on the traces plotted with the map mode.
         raw_rec,
         freq_min=0.5,
         freq_max=250,
-    
+
         margin_ms=1500.,
         filter_order=3,
         dtype="float32",
@@ -95,16 +95,16 @@ eyes on the traces plotted with the map mode.
     )
     lfprec = si.phase_shift(lfprec)
    lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000)
-    
+
    lfprec = si.directional_derivative(lfprec, order=2, edge_order=1)
    lfprec = si.average_across_direction(lfprec)
-    
+
    print(lfprec)
 
 .. parsed-literal::
 
-    AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples 
+    AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples
                                      873.32s (14.56 minutes) - float32 dtype - 159.91 MiB
 
@@ -185,5 +185,3 @@ This motion matches the LFP signal above.
 
 .. image:: drift_with_lfp_files/drift_with_lfp_12_1.png
-
-
diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index cf9cadcfc3..5d7eae9003 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -14,4 +14,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
     process_by_channel_group
     load_your_data_into_sorting
     benchmark_with_hybrid_recordings
-    drift_with_lfp
\ No newline at end of file
+    drift_with_lfp
diff --git a/examples/how_to/drift_with_lfp.py b/examples/how_to/drift_with_lfp.py
index fe84b2ab48..df656bc4ae 100644
--- a/examples/how_to/drift_with_lfp.py
+++ b/examples/how_to/drift_with_lfp.py
@@ -108,6 +108,3 @@
 si.plot_motion(motion, mode='line', ax=ax)
 ax.set_xlim(400, 420)
 ax.set_ylim(800, 1300)
-
-
-
diff --git a/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py b/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
index 0ce60ac7d3..87eaa4c51a 100644
--- a/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
+++ b/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
@@ -19,15 +19,15 @@ from spikeinterface.widgets import plot_motion, plot_motion_info, plot_probe_map
 
 # %%
-# First, let's simulate a drifting recording using the 
+# First, let's simulate a drifting recording using the
 # :code:`spikeinterface.generation` module.
-# 
+#
 # Here the simulated recording has a small zigzag motion along the 'y' axis of the probe.
 
 static_recording, drifting_recording, sorting = generate_drifting_recording(
     num_units=200,
     duration=300.,
-    probe_name='Neuropixel-128', 
+    probe_name='Neuropixel-128',
     generate_displacement_vector_kwargs=dict(
         displacement_sampling_frequency=5.0,
         drift_start_um=[0, 20],
@@ -50,12 +50,12 @@
 
 # %%
 # Here we will use the high-level function :code:`correct_motion()`
-# 
+#
 # Internally, this function does all steps of the motion correction:
 # 1. **activity profile** : detect peaks and localize them along time and depth
 # 2. **motion inference**: estimate the drift motion
 # 3. **motion interpolation**: interpolate traces using the estimated motion
-# 
+#
 # All steps use several methods with many parameters. This is why we can use a
 # 'preset', which combines methods and related parameters.
 #
@@ -70,7 +70,7 @@
 
 # %%
 # The function returns a 'corrected' recording
-# 
+#
 # A new recording is returned; it will interpolate motion-corrected traces
 # when calling get_traces()
 
@@ -78,14 +78,14 @@
 
 # %%
 # Optionally the function also returns the `Motion` object itself
-# 
+#
 
 print(motion)
 
 # %%
 # This motion can be plotted; in our case the motion has been estimated as non-rigid,
 # so we can use `mode='map'` to check the motion across depth.
-# 
+#
 
 plot_motion(motion, mode='line')
 plot_motion(motion, mode='map')
 
@@ -93,7 +93,7 @@
 
 # %%
 # The dict `motion_info` can be used for more plotting.
 # Here we can see on the two top axes the raster of peak depths vs. times before and
 # after correction.
 
 fig = plt.figure()
diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 59b5a590c2..0568650316 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -20,7 +20,7 @@
             peak_sign="neg",
             detect_threshold=8.0,
             exclude_sweep_ms=0.8,
-            radius_um=80.,
+            radius_um=80.0,
         ),
         "select_kwargs": dict(),
         "localize_peaks_kwargs": dict(
@@ -76,7 +76,7 @@
             peak_sign="neg",
             detect_threshold=8.0,
             exclude_sweep_ms=0.8,
-            radius_um=80.,
+            radius_um=80.0,
         ),
         "select_kwargs": dict(),
         "localize_peaks_kwargs": dict(
@@ -196,7 +196,6 @@
 }
 
 
-
 def correct_motion(
     recording,
     preset="nonrigid_accurate",
@@ -398,16 +397,15 @@ def correct_motion(
 
     if not output_motion and not output_motion_info:
         return recording_corrected
-    
-    out = (recording_corrected, )
+
+    out = (recording_corrected,)
     if output_motion:
-        out += (motion, )
+        out += (motion,)
     if output_motion_info:
-        out += (motion_info, )
+        out += (motion_info,)
 
     return out
 
-
 _doc_presets = "\n"
 for k, v in motion_options_preset.items():
     if k == "":
diff --git a/src/spikeinterface/sortingcomponents/motion/__init__.py b/src/spikeinterface/sortingcomponents/motion/__init__.py
index 15233efc32..d2e6a8a3d9 100644
--- a/src/spikeinterface/sortingcomponents/motion/__init__.py
+++ b/src/spikeinterface/sortingcomponents/motion/__init__.py
@@ -1,5 +1,9 @@
 from .motion_utils import Motion
 from .motion_estimation import estimate_motion
-from .motion_interpolation import (correct_motion_on_peaks, interpolate_motion_on_traces,
-                                   InterpolateMotionRecording, interpolate_motion)
+from .motion_interpolation import (
+    correct_motion_on_peaks,
+    interpolate_motion_on_traces,
+    InterpolateMotionRecording,
+    interpolate_motion,
+)
 from .motion_cleaner import clean_motion_vector
diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py
index 396a18bba2..d054995839 100644
--- a/src/spikeinterface/sortingcomponents/motion/decentralized.py
+++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py
@@ -111,8 +111,8 @@ def run(
         bin_um=1.0,
         hist_margin_um=20.0,
         bin_s=1.0,
-        histogram_depth_smooth_um=1.,
-        histogram_time_smooth_s=1.,
+        histogram_depth_smooth_um=1.0,
+        histogram_time_smooth_s=1.0,
         pairwise_displacement_method="conv",
         max_displacement_um=100.0,
         weight_scale="linear",
@@ -135,7 +135,6 @@ def run(
         lsqr_robust_n_iter=20,
         weight_with_amplitude=False,
     ):
-
         dim = ["x", "y", "z"].index(direction)
         contact_depth = recording.get_channel_locations()[:, dim]
@@ -153,7 +152,7 @@ def run(
             win_step_um=win_step_um,
             win_scale_um=win_scale_um,
             win_margin_um=win_margin_um,
-            zero_threshold=None
+            zero_threshold=None,
         )
 
         # make 2D histogram raster
@@ -322,6 +321,7 @@ def compute_pairwise_displacement(
         try:
             import torch
             import torch.nn.functional as F
+
             conv_engine = "torch"
         except ImportError:
             conv_engine = "numpy"
@@ -430,7 +430,6 @@ def compute_pairwise_displacement(
 
     return pairwise_displacement, pairwise_displacement_weight
 
-
 _possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust")
@@ -687,9 +686,6 @@ def jac(p):
 
     return np.squeeze(displacement)
 
-
-
-
 # normxcorr1d is now implemented in dredge
 # we keep the old version here but this will be removed soon
@@ -809,4 +805,4 @@
 #     # get rid of NaNs in zero-variance areas
 #     corr[~npx.isfinite(corr)] = 0
-#     return corr
\ No newline at end of file
+#     return corr
diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py
index 9aa15852cd..3eb83dfbaa 100644
--- a/src/spikeinterface/sortingcomponents/motion/dredge.py
+++ b/src/spikeinterface/sortingcomponents/motion/dredge.py
@@ -2,7 +2,7 @@
 Copy-paste and then refactoring of DREDge
 https://github.com/evarol/dredge
 
-For historical reason, some function from the DREDge package where implemeneted
+For historical reasons, some functions from the DREDge package were implemented
 in spikeinterface in the motion_estimation.py before the DREDge package itself!
 
 Here is a copy/paste (and small rewriting) of some functions from DREDge.
@@ -21,6 +21,7 @@
 but the original functions dredge_ap() and dredge_online_lfp() can be used directly.
 """
+
 import warnings
 from tqdm.auto import trange
@@ -28,10 +29,14 @@
 import gc
 
-from .motion_utils import Motion, get_spatial_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram, get_spatial_bin_edges
-
-
-
+from .motion_utils import (
+    Motion,
+    get_spatial_windows,
+    get_window_domains,
+    scipy_conv1d,
+    make_2d_motion_histogram,
+    get_spatial_bin_edges,
+)
 
 # simple class wrapper to be compliant with estimate_motion
@@ -84,6 +89,7 @@ class DredgeApRegistration:
     device : str or torch.device
         What torch device to run on? E.g., "cpu" or "cuda" or "cuda:1".
     """
+
     @classmethod
     def run(
         cls,
@@ -102,7 +108,6 @@ def run(
         **method_kwargs,
     ):
-
         outs = dredge_ap(
             recording,
             peaks,
@@ -141,7 +146,7 @@ def dredge_ap(
     bin_um=1.0,
     bin_s=1.0,
     max_disp_um=None,
-    time_horizon_s=1000.,
+    time_horizon_s=1000.0,
     mincorr=0.1,
     # weights arguments
     do_window_weights=True,
@@ -217,14 +222,12 @@ def dredge_ap(
     depths_um = peak_depths = peak_locations[direction]
     times_s = peak_times = recording.sample_index_to_time(peaks["sample_index"])
 
-
-
     thomas_kw = thomas_kw if thomas_kw is not None else {}
     xcorr_kw = xcorr_kw if xcorr_kw is not None else {}
     if time_horizon_s:
         xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s)
-
-    #TODO @charlie I think this is a bad to have the dict which is transported to every function
+
+    # TODO @charlie I think it is bad to have this dict transported to every function
     # this should be used only in the histogram function but not in weight_correlation_matrix()
     # only important kwargs should be explicitly reported
     # raster_kw = dict(
@@ -251,8 +254,6 @@ def dredge_ap(
 
     # this will store return values other than the MotionEstimate
     extra = {}
 
-
-
     # TODO charlie I switched this to make_2d_motion_histogram
     # but we need to add all options from the original spike_raster()
     # but I think this is OK
@@ -266,7 +267,7 @@ def dredge_ap(
     #     raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res
     # else:
     #     raster, spatial_bin_edges_um, time_bin_edges_s = raster_res
-    
+
     motion_histogram, time_bin_edges_s, spatial_bin_edges_um = make_2d_motion_histogram(
         recording,
         peaks,
@@ -276,7 +277,7 @@ def dredge_ap(
         direction=direction,
         bin_s=bin_s,
         bin_um=bin_um,
-        hist_margin_um=0.,  # @charlie maybe we should expose this and set +20.
+        hist_margin_um=0.0,  # @charlie maybe we should expose this and set +20. for instance
         spatial_bin_edges=None,
         depth_smooth_um=histogram_depth_smooth_um,
         time_smooth_s=histogram_time_smooth_s,
@@ -285,7 +286,6 @@ def dredge_ap(
 
     # TODO charlie : put the log for histogram
 
-
     # TODO @charlie you should check that we are doing the same thing
     # windows, window_centers = get_spatial_windows(
     #     np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um],
@@ -301,7 +301,7 @@ def dredge_ap(
     dim = ["x", "y", "z"].index(direction)
     contact_depth = recording.get_channel_locations()[:, dim]
     spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])
-    
+
     windows, window_centers = get_spatial_windows(
         contact_depth,
         spatial_bin_centers,
@@ -310,16 +310,13 @@ def dredge_ap(
         win_step_um=win_step_um,
         win_scale_um=win_scale_um,
         win_margin_um=win_margin_um,
-        zero_threshold=1e-5
-    )
-
-
+        zero_threshold=1e-5,
+    )
 
     # TODO charlie : the count has disappeared
     # if extra_outputs and count_masked_correlation:
     #     extra["counts"] = counts
 
-
     # cross-correlate to get D and C
     if precomputed_D_C_maxdisp is None:
         Ds, Cs, max_disp_um = xcorr_windows(
@@ -348,7 +345,7 @@ def dredge_ap(
         spatial_bin_edges_um,
         time_bin_edges_s,
         # raster_kw,  #@charlie this is removed
-        post_transform=post_transform,  # @charlie this isnew
+        post_transform=post_transform,  # @charlie this is new
         lambda_t=thomas_kw.get("lambda_t", DEFAULT_LAMBDA_T),
         eps=thomas_kw.get("eps", DEFAULT_EPS),
         progress_bar=progress_bar,
@@ -403,6 +400,7 @@ class DredgeLfpRegistration:
     The reference is here
     https://www.biorxiv.org/content/10.1101/2023.10.24.563768v1
     """
+
     name = "dredge_lfp"
     need_peak_location = False
     params_doc = """
@@ -447,6 +445,7 @@ class DredgeLfpRegistration:
     device : string or torch.device
         Controls torch device
     """
+
     @classmethod
     def run(
         cls,
@@ -462,7 +461,6 @@ def run(
         verbose,
         progress_bar,
         extra,
-
         **method_kwargs,
     ):
         # Note peaks and peak_locations are not used and can be None
@@ -488,24 +486,17 @@ def run(
         return motion
 
-
-
-
-
 def dredge_online_lfp(
     lfp_recording,
-    direction='y',
+    direction="y",
     # nonrigid window construction arguments
     rigid=True,
     win_shape="gaussian",
     win_step_um=800,
     win_scale_um=850,
     win_margin_um=None,
-
     chunk_len_s=10.0,
     max_disp_um=500,
-
-
     time_horizon_s=None,
     # weighting arguments
     mincorr=0.8,
@@ -537,7 +528,6 @@ def dredge_online_lfp(
     # contact pos is only along the direction
     contact_depth = lfp_recording.get_channel_locations()[:, dim]
 
-
     fs = lfp_recording.get_sampling_frequency()
     T_total = lfp_recording.get_num_samples()
     T_chunk = min(int(np.floor(fs * chunk_len_s)), T_total)
@@ -563,7 +553,6 @@ def dredge_online_lfp(
         bin_s=1 / fs,  # only relevant for time_horizon_s
     )
 
-
    # here we check that contact positions are unique along the direction
    if contact_depth.size != np.unique(contact_depth).size:
        raise ValueError(
@@ -599,9 +588,7 @@ def dredge_online_lfp(
    # below, t0 is start of prev chunk, t1 start of cur chunk, t2 end of cur
    t0, t1 = 0, T_chunk
    traces0 = lfp_recording.get_traces(start_frame=t0, end_frame=t1)
-    Ds0, Cs0, max_disp_um = xcorr_windows(
-        traces0.T, windows, contact_depth, win_scale_um, **full_xcorr_kw
-    )
+    Ds0, Cs0, max_disp_um = xcorr_windows(traces0.T, windows, contact_depth, win_scale_um, **full_xcorr_kw)
    full_xcorr_kw["max_disp_um"] = max_disp_um
    Ss0, mincorr0 = threshold_correlation_matrix(
        Cs0,
@@ -646,19 +633,14 @@ def dredge_online_lfp(
        )
 
        # cross-correlation in current chunk
-        Ds1, Cs1, _ = xcorr_windows(
-            traces1.T, windows, contact_depth, win_scale_um, **full_xcorr_kw
-        )
+        Ds1, Cs1, _ = xcorr_windows(traces1.T, windows, contact_depth, win_scale_um, **full_xcorr_kw)
        Ss1, mincorr1 = threshold_correlation_matrix(
            Cs1,
            mincorr_percentile=mincorr_percentile,
            mincorr=mincorr,
            **threshold_kw,
        )
-        Ss10, _ = threshold_correlation_matrix(
-            Cs10, mincorr=mincorr1, t_offset_bins=T_chunk, **threshold_kw
-        )
-
+        Ss10, _ = threshold_correlation_matrix(Cs10, mincorr=mincorr1, t_offset_bins=T_chunk, **threshold_kw)
 
        if extra_outputs:
            extra["mincorrs"].append(mincorr1)
 
@@ -692,7 +674,8 @@ def dredge_online_lfp(
    else:
        return motion
 
-dredge_online_lfp.__doc__ = dredge_online_lfp.__doc__.format(DredgeLfpRegistration.params_doc)
+
+dredge_online_lfp.__doc__ = dredge_online_lfp.__doc__.format(DredgeLfpRegistration.params_doc)
 
 
 # -- functions from dredgelib (zone forbidden for Sam)
 
@@ -721,7 +704,6 @@ def laplacian(n, wink=True, eps=DEFAULT_EPS, lambd=1.0, ridge_mask=None):
     return lap
 
-
 def neg_hessian_likelihood_term(Ub, Ub_prevcur=None, Ub_curprev=None):
     """Newton step coefficients
 
@@ -761,12 +743,7 @@ def newton_rhs(
         # online case
         align_term = (Ub_prevcur.T + Ub_curprev) @ Pb_prev
-        rhs = (
-            align_term
-            + grad_at_0
-            + (Ub_curprev * Db_curprev).sum(1)
-            - (Ub_prevcur * Db_prevcur).sum(0)
-        )
+        rhs = align_term + grad_at_0 + (Ub_curprev * Db_curprev).sum(1) - (Ub_prevcur * Db_prevcur).sum(0)
 
     return rhs
 
@@ -881,9 +858,7 @@ def thomas_solve(
         P = np.zeros((B, T))
         extra["HU"] = np.zeros((B, T, T))
         for b in range(B):
-            P[b], extra["HU"][b] = newton_solve_rigid(
-                Ds[b], Us[b], L_t[b], **online_kw_rhs(b)
-            )
+            P[b], extra["HU"][b] = newton_solve_rigid(Ds[b], Us[b], L_t[b], **online_kw_rhs(b))
         return P, extra
 
     # spatial prior is a sparse, block tridiagonal kronecker product
@@ -893,31 +868,21 @@ def thomas_solve(
     Lambda_s_offdiag = laplacian(T, eps=0, lambd=-lambda_s / 2)
 
     # initialize block-LU stuff and forward variable
-    alpha_hat_b = (
-        L_t[0]
-        + Lambda_s_diagb
-        + neg_hessian_likelihood_term(Us[0], **online_kw_hess(0))
-    )
-    targets = np.c_[
-        Lambda_s_offdiag, newton_rhs(Us[0], Ds[0], **online_kw_rhs(0))
-    ]
+    alpha_hat_b = L_t[0] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[0], **online_kw_hess(0))
+    targets = np.c_[Lambda_s_offdiag, newton_rhs(Us[0], Ds[0], **online_kw_rhs(0))]
     res = solve(alpha_hat_b, targets, assume_a="pos")
     assert res.shape == (T, T + 1)
     gamma_hats = [res[:, :T]]
     ys = [res[:, T]]
 
     # forward pass
-    for b in (trange(1, B, desc="Solve") if progress_bar else range(1, B)):
+    for b in trange(1, B, desc="Solve") if progress_bar else range(1, B):
         if b < B - 1:
             Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s, ridge_mask=had_weights[b])
         else:
             Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s / 2, ridge_mask=had_weights[b])
 
-        Ab = (
-            L_t[b]
-            + Lambda_s_diagb
-            + neg_hessian_likelihood_term(Us[b], **online_kw_hess(b))
-        )
+        Ab = L_t[b] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[b], **online_kw_hess(b))
         alpha_hat_b = Ab - Lambda_s_offdiag @ gamma_hats[b - 1]
         targets[:, T] = newton_rhs(Us[b], Ds[b], **online_kw_rhs(b))
         targets[:, T] -= Lambda_s_offdiag @ ys[b - 1]
@@ -938,7 +903,6 @@ def thomas_solve(
 
     return P, extra
 
-
 def threshold_correlation_matrix(
     Cs,
     mincorr=0.0,
@@ -952,10 +916,7 @@ def threshold_correlation_matrix(
     soft=True,
 ):
     if mincorr_percentile is not None:
-        diags = [
-            np.diagonal(Cs, offset=j, axis1=1, axis2=2).ravel()
-            for j in range(1, mincorr_percentile_nneighbs)
-        ]
+        diags = [np.diagonal(Cs, offset=j, axis1=1, axis2=2).ravel() for j in range(1, mincorr_percentile_nneighbs)]
         mincorr = np.percentile(
             np.concatenate(diags),
             mincorr_percentile,
         )
@@ -974,12 +935,7 @@ def threshold_correlation_matrix(
         Ss = np.square((Cs >= mincorr) * Cs)
     else:
         Ss = (Cs >= mincorr).astype(Cs.dtype)
-    if (
-        time_horizon_s is not None
-        and time_horizon_s > 0
-        and T is not None
-        and time_horizon_s < T
-    ):
+    if time_horizon_s is not None and time_horizon_s > 0 and T is not None and time_horizon_s < T:
         tt0 = bin_s * np.arange(T)
         tt1 = tt0
         if t_offset_bins:
@@ -1136,11 +1092,7 @@ def calc_corr_decent_pair(
 
     # pick torch device if unset
     if device is None:
-        device = (
-            torch.device("cuda")
-            if torch.cuda.is_available()
-            else torch.device("cpu")
-        )
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
     # process rasters into the tensors we need for conv2ds below
     # convert to TxD device floats
@@ -1152,9 +1104,7 @@ def calc_corr_decent_pair(
     C = np.zeros((Ta, Tb), dtype=np.float32)
     for i in range(0, Ta, batch_size):
         for j in range(0, Tb, batch_size):
-            dt_bins = min(
-                abs(i - j), abs(i + batch_size - j), abs(i - j - batch_size)
-            )
+            dt_bins = min(abs(i - j), abs(i + batch_size - j), abs(i - j - batch_size))
             if max_dt_bins and dt_bins > max_dt_bins:
                 continue
             weights_ = weights
@@ -1186,53 +1136,52 @@ def normxcorr1d(
     normalized=True,
     padding="same",
     conv_engine="torch",
-
 ):
     """normxcorr1d: Normalized cross-correlation, optionally weighted
 
-    The API is like torch's F.conv1d, except I have accidentally
-    changed the position of input/weights -- template acts like weights,
-    and x acts like input.
-
-    Returns the cross-correlation of `template` and `x` at spatial lags
-    determined by `mode`. Useful for estimating the location of `template`
-    within `x`.
-
-    This might not be the most efficient implementation -- ideas welcome.
-    It uses a direct convolutional translation of the formula
-    corr = (E[XY] - EX EY) / sqrt(var X * var Y)
-
-    This also supports weights! In that case, the usual adaptation of
-    the above formula is made to the weighted case -- and all of the
-    normalizations are done per block in the same way.
-
-    Arguments
-    ---------
-    template : tensor, shape (num_templates, length)
-        The reference template signal
-    x : tensor, 1d shape (length,) or 2d shape (num_inputs, length)
-        The signal in which to find `template`
-    weights : tensor, shape (length,)
-        Will use weighted means, variances, covariances if supplied.
-    centered : bool
-        If true, means will be subtracted (per weighted patch).
-    normalized : bool
-        If true, normalize by the variance (per weighted patch).
-    padding : int, optional
-        How far to look? if unset, we'll use half the length
-    conv_engine : "torch" | "numpy"
-        What library to use for computing cross-correlations.
-        If numpy, falls back to the scipy correlate function.
-conv_engine
-    Returns
-    -------
-    corr : tensor
+    The API is like torch's F.conv1d, except I have accidentally
+    changed the position of input/weights -- template acts like weights,
+    and x acts like input.
+
+    Returns the cross-correlation of `template` and `x` at spatial lags
+    determined by `mode`. Useful for estimating the location of `template`
+    within `x`.
+
+    This might not be the most efficient implementation -- ideas welcome.
+    It uses a direct convolutional translation of the formula
+        corr = (E[XY] - EX EY) / sqrt(var X * var Y)
+
+    This also supports weights! In that case, the usual adaptation of
+    the above formula is made to the weighted case -- and all of the
+    normalizations are done per block in the same way.
+
+    Arguments
+    ---------
+    template : tensor, shape (num_templates, length)
+        The reference template signal
+    x : tensor, 1d shape (length,) or 2d shape (num_inputs, length)
+        The signal in which to find `template`
+    weights : tensor, shape (length,)
+        Will use weighted means, variances, covariances if supplied.
+    centered : bool
+        If true, means will be subtracted (per weighted patch).
+    normalized : bool
+        If true, normalize by the variance (per weighted patch).
+    padding : int, optional
+        How far to look? If unset, we'll use half the length
+    conv_engine : "torch" | "numpy"
+        What library to use for computing cross-correlations.
+        If numpy, falls back to the scipy correlate function.
+
+    Returns
+    -------
+    corr : tensor
     """
-
     if conv_engine == "torch":
         import torch
         import torch.nn.functional as F
+
         conv1d = F.conv1d
         npx = torch
     elif conv_engine == "numpy":
@@ -1297,9 +1246,7 @@ def normxcorr1d(
 
     # compute variances for denominator, using var X = E[X^2] - (EX)^2
     if normalized:
-        var_template = conv1d(
-            onesx, wt * template, padding=padding
-        )
+        var_template = conv1d(onesx, wt * template, padding=padding)
         var_template /= Nx
         var_x = conv1d(wx * x, weights, padding=padding)
         var_x /= Nx
@@ -1354,20 +1301,12 @@ def get_weights(
     if isinstance(weights_threshold_low, tuple):
         nspikes_threshold_low, amp_threshold_low = weights_threshold_low
         unif = np.full_like(windows[0], 1 / len(windows[0]))
-        weights_threshold_low = (
-            scale_fn(amp_threshold_low)
-            * windows
-            @ (nspikes_threshold_low * unif)
-        )
+        weights_threshold_low = scale_fn(amp_threshold_low) * windows @ (nspikes_threshold_low * unif)
         weights_threshold_low = weights_threshold_low[:, None]
     if isinstance(weights_threshold_high, tuple):
         nspikes_threshold_high, amp_threshold_high = weights_threshold_high
         unif = np.full_like(windows[0], 1 / len(windows[0]))
-        weights_threshold_high = (
-            scale_fn(amp_threshold_high)
-            * windows
-            @ (nspikes_threshold_high * unif)
-        )
+        weights_threshold_high = scale_fn(amp_threshold_high) * windows @ (nspikes_threshold_high * unif)
         weights_threshold_high = weights_threshold_high[:, None]
     weights_thresh = weights_orig.copy()
     weights_thresh[weights_orig < weights_threshold_low] = 0
@@ -1375,6 +1314,7 @@ def get_weights(
 
     return weights, weights_thresh, p_inds
 
+
 def weight_correlation_matrix(
     Ds,
     Cs,
@@ -1436,7 +1376,7 @@ def weight_correlation_matrix(
         raster,
         depth_bin_edges,
         time_bin_edges,
-        #raster_kw,
+        # raster_kw,
         post_transform=post_transform,
         weights_threshold_low=weights_threshold_low,
         weights_threshold_high=weights_threshold_high,
diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py
index a49d5bd639..f5e2e30d4a 100644
--- a/src/spikeinterface/sortingcomponents/motion/iterative_template.py
+++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py
@@ -96,8 +96,6 @@ def run(
             zero_threshold=None,
         )
 
-
-
         # make a 3D histogram
         motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms(
             recording,
@@ -143,7 +141,6 @@ def run(
 
         return motion
 
-
 def iterative_template_registration(
     spikecounts_hist_images,
     non_rigid_windows=None,
diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py
index de2c7df4cc..2ac20ad46d 100644
--- a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py
+++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py
@@ -2,6 +2,7 @@
 # TODO this needs a full rewrite with the motion object
 
+
 def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None):
     """
     Simple machinery to remove spurious fast bumps in the motion vector.
@@ -69,5 +70,3 @@ def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=3
         motion_clean = scipy.signal.fftconvolve(motion_clean, smooth_kernel, mode="same", axes=0)
 
     return motion_clean
-
-
diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py
index f7f4f4ad66..2d8564fc54 100644
--- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 
-
 from spikeinterface.sortingcomponents.tools import make_multi_method_doc
 
@@ -22,8 +21,8 @@ def estimate_motion(
     direction="y",
     rigid=False,
     win_shape="gaussian",
-    win_step_um=50.0, # @alessio charlie is proposing here instead 400
-    win_scale_um=150.0, # @alessio charlie is proposing here instead 400
+    win_step_um=50.0,  # @alessio charlie is proposing here instead 400
+    win_scale_um=150.0,  # @alessio charlie is proposing here instead 400
     win_margin_um=None,
     method="decentralized",
     extra_outputs=False,
@@ -33,8 +32,8 @@ def estimate_motion(
     **method_kwargs,
 ):
     """
-
-
+
+
     Estimate motion with several possible methods.
 
     Most methods except dredge_lfp need peaks and their localization beforehand.
@@ -98,7 +97,6 @@ def estimate_motion(
     if margin_um is not None:
         warnings.warn("estimate_motion() margin_um has been removed, use hist_margin_um or win_margin_um")
 
-
     # TODO handle multi segment one day : Charlie this is for you
     assert recording.get_num_segments() == 1, "At the moment estimate_motion handles only a single segment"
 
@@ -119,13 +117,11 @@ def estimate_motion(
             peaks,
             peak_locations,
             direction,
-
             rigid,
             win_shape,
             win_step_um,
             win_scale_um,
             win_margin_um,
-
             verbose,
             progress_bar,
             extra,
diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py
index 228110b7ec..a848ca1746 100644
--- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py
+++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py
@@ -229,22 +229,21 @@ def copy(self):
             [d.copy() for d in self.displacement],
             [t.copy() for t in self.temporal_bins_s],
             self.spatial_bins_um.copy(),
-            direction=self.direction, 
+            direction=self.direction,
             interpolation_method=self.interpolation_method,
         )
 
-
 def get_spatial_windows(
-        contact_depth,
-        spatial_bin_centers,
-        rigid=False,
-        win_shape="gaussian",
-        win_step_um=50.0,
-        win_scale_um=150.0,
-        win_margin_um=None,
-        zero_threshold=None
-    ):
+    contact_depth,
+    spatial_bin_centers,
+    rigid=False,
+    win_shape="gaussian",
+    win_step_um=50.0,
+    win_scale_um=150.0,
+    win_margin_um=None,
+    zero_threshold=None,
+):
     """
     Generate spatial windows (taper) for non-rigid motion. For rigid motion,
     this is equivalent to having one unique rectangular window that covers the entire probe.
@@ -297,14 +296,14 @@ def get_spatial_windows(
         middle = (spatial_bin_centers[0] + spatial_bin_centers[-1]) / 2.0
         window_centers = np.array([middle])
     else:
-        if win_scale_um <= win_step_um/5.:
+        if win_scale_um <= win_step_um / 5.0:
             warnings.warn(
                 f"get_spatial_windows(): spatial windows are probably not overlapping because {win_scale_um=} and {win_step_um=}"
             )
 
         if win_margin_um is None:
             # this ensures that first/last windows do not overflow outside the probe
-            win_margin_um = -win_scale_um / 2.
+            win_margin_um = -win_scale_um / 2.0
 
         min_ = np.min(contact_depth) - win_margin_um
         max_ = np.max(contact_depth) + win_margin_um
@@ -388,7 +387,6 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um):
 
     return spatial_bins
 
-
 def make_2d_motion_histogram(
     recording,
     peaks,
@@ -465,7 +463,7 @@ def make_2d_motion_histogram(
         weights = None
 
     motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights)
-    
+
     # average amplitude in each bin
     if weight_with_amplitude and avg_in_bin:
         bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges))
diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py
index 218d9036aa..8133c1fa6b 100644
--- a/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py
+++ b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py
@@ -1,8 +1,6 @@
 import pytest
 
-
-
 def test_dredge_online_lfp():
     pass
diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py
index 1168b65c79..3c83a56b9d 100644
--- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py
@@ -57,7 +57,7 @@ def dataset_fixture(create_cache_folder):
 def test_estimate_motion(dataset):
     # recording, sorting = make_dataset()
     recording, sorting, cache_folder = dataset
-    
+
     peaks = np.load(cache_folder / "dataset_peaks.npy")
     peak_locations = np.load(cache_folder / "dataset_peak_locations.npy")
 
@@ -222,6 +222,7 @@ def test_estimate_motion(dataset):
 
 if __name__ == "__main__":
     import tempfile
+
     with tempfile.TemporaryDirectory() as tmpdirname:
         cache_folder = Path(tmpdirname)
         args = setup_dataset_and_peaks(cache_folder)
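The hunks in `src/spikeinterface/preprocessing/motion.py` change how `correct_motion()` assembles its return value (`out = (recording_corrected,)`, then optionally `motion` and `motion_info`). A minimal usage sketch of that call surface, assuming the simulated recording from the `plot_1_estimate_motion.py` tutorial touched above; parameter values are illustrative only, not prescriptive:

```python
# Minimal sketch of the correct_motion() outputs touched by this patch.
# Assumes a spikeinterface version carrying this refactor.
from spikeinterface.generation import generate_drifting_recording
from spikeinterface.preprocessing import correct_motion

# the tutorial's simulator returns (static, drifting, sorting)
static_recording, drifting_recording, sorting = generate_drifting_recording(
    num_units=200,
    duration=300.0,
)

# With both output flags enabled, the function returns the tuple
# (recording_corrected, motion, motion_info), matching the
# `out = (recording_corrected,)` build-up in the patched code;
# with both flags off, only the corrected recording is returned.
rec_corrected, motion, motion_info = correct_motion(
    drifting_recording,
    preset="nonrigid_accurate",
    output_motion=True,
    output_motion_info=True,
)
print(motion)
```

Keeping the bare-recording return for the default flags preserves backward compatibility, while the tuple build-up lets callers opt in to the `Motion` object and the `motion_info` dict without a second estimation pass.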
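Likewise, for the `dredge_lfp` path in `dredge.py`: `DredgeLfpRegistration.run()` notes that `peaks` and `peak_locations` are unused and may be `None`, so LFP-based estimation goes through `estimate_motion()` with no peak-detection step. A hedged sketch, assuming `lfprec` is the preprocessed LFP recording from the `drift_with_lfp` how-to (band-passed 0.5-250 Hz, phase-shifted, resampled to 250 Hz):

```python
# Sketch of LFP-based (DREDge) motion estimation with the refactored
# module layout; `lfprec` is assumed to be prepared as in the how-to.
from spikeinterface.sortingcomponents.motion import estimate_motion

motion = estimate_motion(
    lfprec,
    peaks=None,           # dredge_lfp needs no detected peaks
    peak_locations=None,  # ...and no peak localization
    method="dredge_lfp",
    rigid=True,
    progress_bar=True,
)
print(motion)  # Motion object: displacement over temporal/spatial bins
```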