diff --git a/doc/api.rst b/doc/api.rst index 3e825084e7..ac221ac602 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -408,12 +408,6 @@ Peak Detection .. autofunction:: detect_peaks -Motion Correction -~~~~~~~~~~~~~~~~~ -.. automodule:: spikeinterface.sortingcomponents.motion_interpolation - - .. autoclass:: InterpolateMotionRecording - Clustering ~~~~~~~~~~ .. automodule:: spikeinterface.sortingcomponents.clustering @@ -425,3 +419,15 @@ Template Matching .. automodule:: spikeinterface.sortingcomponents.matching .. autofunction:: find_spikes_from_templates + +Motion Correction +~~~~~~~~~~~~~~~~~ +.. automodule:: spikeinterface.sortingcomponents.motion + + .. autoclass:: Motion + .. autofunction:: estimate_motion + .. autofunction:: interpolate_motion + .. autofunction:: correct_motion_on_peaks + .. autofunction:: interpolate_motion_on_traces + .. autofunction:: clean_motion_vector + .. autoclass:: InterpolateMotionRecording diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 5870d87955..9975bb1a4b 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -24,7 +24,7 @@ order to smoothly inject spikes into the recording. import spikeinterface.generation as sgen import spikeinterface.widgets as sw - from spikeinterface.sortingcomponents.motion_estimation import estimate_motion + from spikeinterface.sortingcomponents.motion import estimate_motion import numpy as np import matplotlib.pyplot as plt @@ -1202,63 +1202,63 @@ drifts when injecting hybrid spikes. 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 1. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 2. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 3. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 4. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 5. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 6. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 7. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 8. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 9. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 10. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 11. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 12. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 13. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 14. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 15. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385] diff --git a/doc/how_to/drift_with_lfp.rst b/doc/how_to/drift_with_lfp.rst new file mode 100644 index 0000000000..0decc1058a --- /dev/null +++ b/doc/how_to/drift_with_lfp.rst @@ -0,0 +1,163 @@ +Estimate drift using the LFP traces +=================================== + +Drift is a well known issue for long shank probes. Some datasets, especially from primates and humans, +can experience very fast motion due to breathing and heart beats. In these cases, the standard motion +estimation methods that use detected spikes as a basis for motion inference will fail, because there +are not enough spikes to "follow" such fast drifts. + +Charlie Windolf and colleagues from the Paninski Lab at Columbia have developed a method to estimate +the motion using the LFP signal: **DREDge**. (more details about the method in the paper +`DREDge: robust motion correction for high-density extracellular recordings across species `_). + +This method is particularly suited for the open dataset recorded at Massachusetts General Hospital by Angelique Paulk and colleagues in humans (more details in the [paper](https://doi.org/10.1038/s41593-021-00997-0)). The dataset can be dowloaed from [datadryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.d2547d840) and it contains recordings on human patients with a Neuropixels probe, some of which with very high and fast motion on the probe, which prevents accurate spike sorting without a proper and adequate motion correction + +The **DREDge** method has two options: **dredge_lfp** and **dredge_ap**, which have both been ported inside `SpikeInterface`. + +Here we will demonstrate the **dredge_lfp** method to estimate the fast and high drift on this recording. + +For each patient, the dataset contains two streams: + +* a highpass "action potential" (AP), sampled at 30kHz +* a lowpass "local field" (LF) sampled at 2.5kHz + +For this demonstration, we will use the LF stream. + +.. code:: ipython3 + + %matplotlib inline + %load_ext autoreload + %autoreload 2 + +.. code:: ipython3 + + from pathlib import Path + import matplotlib.pyplot as plt + + import spikeinterface.full as si + from spikeinterface.sortingcomponents.motion import estimate_motion + +.. code:: ipython3 + + # the dataset has been locally downloaded + base_folder = Path("/mnt/data/sam/DataSpikeSorting/") + np_data_drift = base_folder / 'human_neuropixel/Pt02/' + +Read the spikeglx file +~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: ipython3 + + raw_rec = si.read_spikeglx(np_data_drift) + print(raw_rec) + + +.. parsed-literal:: + + SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples + 873.32s (14.56 minutes) - int16 dtype - 1.56 GiB + + +Preprocessing +~~~~~~~~~~~~~ + +Contrary to the **dredge_ap** approach, which needs detected peaks and peak locations, the **dredge_lfp** +method is estimating the motion directly on traces. +Importantly, the method requires some additional pre-processing steps: + * ``bandpass_filter``: to "focus" the signal on a particular band + * ``phase_shift``: to compensate for the sampling misalignement + * ``resample``: to further reduce the sampling fequency of the signal and speed up the computation. The sampling frequency of the estimated motion will be the same as the resampling frequency. Here we choose 250Hz, which corresponds to a sampling interval of 4ms. + * ``directional_derivative``: this optional step applies a second order derivative in the spatial dimension to enhance edges on the traces. + This is not a general rules and need to be tested case by case. + * ``average_across_direction``: Neuropixels 1.0 probes have two contacts per depth. This steps averages them to obtain a unique virtual signal along the probe depth ("y" in ``spikeinterface``). + +After appying this preprocessing chain, the motion can be estimated almost by eyes ont the traces plotted with the map mode. + +.. code:: ipython3 + + lfprec = si.bandpass_filter( + raw_rec, + freq_min=0.5, + freq_max=250, + + margin_ms=1500., + filter_order=3, + dtype="float32", + add_reflect_padding=True, + ) + lfprec = si.phase_shift(lfprec) + lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000) + + lfprec = si.directional_derivative(lfprec, order=2, edge_order=1) + lfprec = si.average_across_direction(lfprec) + + print(lfprec) + + +.. parsed-literal:: + + AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples + 873.32s (14.56 minutes) - float32 dtype - 159.91 MiB + + +.. code:: ipython3 + + %matplotlib inline + si.plot_traces(lfprec, backend="matplotlib", mode="map", clim=(-0.05, 0.05), time_range=(400, 420)) + + + +.. image:: drift_with_lfp_files/drift_with_lfp_8_1.png + + +Run the method +~~~~~~~~~~~~~~ + +``estimate_motion()`` is the generic function to estimate motion with multiple +methods in ``spikeinterface``. + +This function returns a ``Motion`` object and we can notice that the interval is exactly +the same as downsampled signal. + +Here we use ``rigid=True``, which means that we have one unqiue signal to +describe the motion across the entire probe depth. + +.. code:: ipython3 + + motion = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True) + motion + + +.. parsed-literal:: + + Online chunks [10.0s each]: 0%| | 0/87 [00:00 1e-5) + window_slice = slice(window_slice[0], window_slice[-1]) + if verbose: + print(f"Computing pairwise displacement: {i + 1} / {len(non_rigid_windows)}") + + pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement( + motion_histogram[:, window_slice], + bin_um, + window=win[window_slice], + method=pairwise_displacement_method, + weight_scale=weight_scale, + error_sigma=error_sigma, + conv_engine=conv_engine, + torch_device=torch_device, + batch_size=batch_size, + max_displacement_um=max_displacement_um, + normalized_xcorr=normalized_xcorr, + centered_xcorr=centered_xcorr, + corr_threshold=corr_threshold, + time_horizon_s=time_horizon_s, + bin_s=bin_s, + progress_bar=False, + ) + + if spatial_prior: + all_pairwise_displacements[i] = pairwise_displacement + all_pairwise_displacement_weights[i] = pairwise_displacement_weight + + if extra is not None: + extra["pairwise_displacement_list"].append(pairwise_displacement) + + if verbose: + print(f"Computing global displacement: {i + 1} / {len(non_rigid_windows)}") + + # TODO: if spatial_prior, do this after the loop + if not spatial_prior: + motion_array[:, i] = compute_global_displacement( + pairwise_displacement, + pairwise_displacement_weight=pairwise_displacement_weight, + convergence_method=convergence_method, + robust_regression_sigma=robust_regression_sigma, + lsqr_robust_n_iter=lsqr_robust_n_iter, + temporal_prior=temporal_prior, + spatial_prior=spatial_prior, + soft_weights=soft_weights, + progress_bar=False, + ) + + if spatial_prior: + motion_array = compute_global_displacement( + all_pairwise_displacements, + pairwise_displacement_weight=all_pairwise_displacement_weights, + convergence_method=convergence_method, + robust_regression_sigma=robust_regression_sigma, + lsqr_robust_n_iter=lsqr_robust_n_iter, + temporal_prior=temporal_prior, + spatial_prior=spatial_prior, + soft_weights=soft_weights, + progress_bar=False, + ) + elif len(non_rigid_windows) > 1: + # if spatial_prior is False, we still want keep the spatial bins + # correctly offset from each other + if force_spatial_median_continuity: + for i in range(len(non_rigid_windows) - 1): + motion_array[:, i + 1] -= np.median(motion_array[:, i + 1] - motion_array[:, i]) + + # try to avoid constant offset + # let the user choose how to do this. here are some ideas. + # (one can also -= their own number on the result of this function.) + if reference_displacement == "mean": + motion_array -= motion_array.mean() + elif reference_displacement == "median": + motion_array -= np.median(motion_array) + elif reference_displacement == "time": + # reference the motion to 0 at a specific time, independently in each window + reference_displacement_bin = np.digitize(reference_displacement_time_s, temporal_hist_bin_edges) - 1 + motion_array -= motion_array[reference_displacement_bin, :] + elif reference_displacement == "mode_search": + # just a sketch of an idea + # things might want to change, should have a configurable bin size, + # should use a call to histogram instead of the loop, ... + step_size = 0.1 + round_mode = np.round # floor? + best_ref = np.median(motion_array) + max_zeros = np.sum(round_mode(motion_array - best_ref) == 0) + for ref in np.arange(np.floor(motion_array.min()), np.ceil(motion_array.max()), step_size): + n_zeros = np.sum(round_mode(motion_array - ref) == 0) + if n_zeros > max_zeros: + max_zeros = n_zeros + best_ref = ref + motion_array -= best_ref + + # replace nan by zeros + np.nan_to_num(motion_array, copy=False) + + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) + + return motion + + +def compute_pairwise_displacement( + motion_hist, + bin_um, + method="conv", + weight_scale="linear", + error_sigma=0.2, + conv_engine="numpy", + torch_device=None, + batch_size=1, + max_displacement_um=1500, + corr_threshold=0, + time_horizon_s=None, + normalized_xcorr=True, + centered_xcorr=True, + bin_s=None, + progress_bar=False, + window=None, +): + """ + Compute pairwise displacement + """ + from scipy import linalg + + if conv_engine is None: + # use torch if installed + try: + import torch + + conv_engine = "torch" + except ImportError: + conv_engine = "numpy" + + if conv_engine == "torch": + import torch + + assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" + size = motion_hist.shape[0] + pairwise_displacement = np.zeros((size, size), dtype="float32") + + if time_horizon_s is not None: + band_width = int(np.ceil(time_horizon_s / bin_s)) + if band_width >= size: + time_horizon_s = None + + if conv_engine == "torch": + if torch_device is None: + torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if method == "conv": + if max_displacement_um is None: + n = motion_hist.shape[1] // 2 + else: + n = min( + motion_hist.shape[1] // 2, + int(np.ceil(max_displacement_um // bin_um)), + ) + possible_displacement = np.arange(-n, n + 1) * bin_um + + xrange = trange if progress_bar else range + + motion_hist_engine = motion_hist + window_engine = window + if conv_engine == "torch": + motion_hist_engine = torch.as_tensor(motion_hist, dtype=torch.float32, device=torch_device) + window_engine = torch.as_tensor(window, dtype=torch.float32, device=torch_device) + + pairwise_displacement = np.empty((size, size), dtype=np.float32) + correlation = np.empty((size, size), dtype=motion_hist.dtype) + + for i in xrange(0, size, batch_size): + corr = normxcorr1d( + motion_hist_engine, + motion_hist_engine[i : i + batch_size], + weights=window_engine, + padding=possible_displacement.size // 2, + conv_engine=conv_engine, + normalized=normalized_xcorr, + centered=centered_xcorr, + ) + if conv_engine == "torch": + max_corr, best_disp_inds = torch.max(corr, dim=2) + best_disp = possible_displacement[best_disp_inds.cpu()] + pairwise_displacement[i : i + batch_size] = best_disp + correlation[i : i + batch_size] = max_corr.cpu() + elif conv_engine == "numpy": + best_disp_inds = np.argmax(corr, axis=2) + max_corr = np.take_along_axis(corr, best_disp_inds[..., None], 2).squeeze() + best_disp = possible_displacement[best_disp_inds] + pairwise_displacement[i : i + batch_size] = best_disp + correlation[i : i + batch_size] = max_corr + + if corr_threshold is not None and corr_threshold > 0: + which = correlation > corr_threshold + correlation *= which + + elif method == "phase_cross_correlation": + # this 'phase_cross_correlation' is an old idea from Julien/Charlie/Erden that is kept for testing + # but this is not very releveant + try: + import skimage.registration + except ImportError: + raise ImportError("To use the 'phase_cross_correlation' method install scikit-image") + + errors = np.zeros((size, size), dtype="float32") + loop = range(size) + if progress_bar: + loop = tqdm(loop) + for i in loop: + for j in range(size): + shift, error, diffphase = skimage.registration.phase_cross_correlation( + motion_hist[i, :], motion_hist[j, :] + ) + pairwise_displacement[i, j] = shift * bin_um + errors[i, j] = error + correlation = 1 - errors + + else: + raise ValueError( + f"method {method} does not exist for compute_pairwise_displacement. Current possible methods are" + f" 'conv' or 'phase_cross_correlation'" + ) + + if weight_scale == "linear": + # between 0 and 1 + pairwise_displacement_weight = correlation + elif weight_scale == "exp": + pairwise_displacement_weight = np.exp((correlation - 1) / error_sigma) + + # handle the time horizon by multiplying the weights by a + # matrix with the time horizon on its diagonal bands. + if method == "conv" and time_horizon_s is not None and time_horizon_s > 0: + horizon_matrix = linalg.toeplitz( + np.r_[np.ones(band_width, dtype=bool), np.zeros(size - band_width, dtype=bool)] + ) + pairwise_displacement_weight *= horizon_matrix + + return pairwise_displacement, pairwise_displacement_weight + + +_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust") + + +def compute_global_displacement( + pairwise_displacement, + pairwise_displacement_weight=None, + sparse_mask=None, + temporal_prior=True, + spatial_prior=True, + soft_weights=False, + convergence_method="lsmr", + robust_regression_sigma=2, + lsqr_robust_n_iter=20, + progress_bar=False, +): + """ + Compute global displacement + + Arguments + --------- + pairwise_displacement : time x time array + pairwise_displacement_weight : time x time array + sparse_mask : time x time array + convergence_method : str + One of "gradient" + + """ + import scipy + from scipy.optimize import minimize + from scipy.sparse import csr_matrix + from scipy.sparse.linalg import lsqr + from scipy.stats import zscore + + if convergence_method == "gradient_descent": + size = pairwise_displacement.shape[0] + + D = pairwise_displacement + if pairwise_displacement_weight is not None or sparse_mask is not None: + # weighted problem + if pairwise_displacement_weight is None: + pairwise_displacement_weight = np.ones_like(D) + if sparse_mask is None: + sparse_mask = np.ones_like(D) + W = pairwise_displacement_weight * sparse_mask + + I, J = np.nonzero(W > 0) + Wij = W[I, J] + Dij = D[I, J] + W = csr_matrix((Wij, (I, J)), shape=W.shape) + WD = csr_matrix((Wij * Dij, (I, J)), shape=W.shape) + fixed_terms = (W @ WD).diagonal() - (WD @ W).diagonal() + diag_WW = (W @ W).diagonal() + Wsq = W.power(2) + + def obj(p): + return 0.5 * np.square(Wij * (Dij - (p[I] - p[J]))).sum() + + def jac(p): + return fixed_terms - 2 * (Wsq @ p) + 2 * p * diag_WW + + else: + # unweighted problem, it's faster when we have no weights + fixed_terms = -D.sum(axis=1) + D.sum(axis=0) + + def obj(p): + v = np.square((D - (p[:, None] - p[None, :]))).sum() + return 0.5 * v + + def jac(p): + return fixed_terms + 2 * (size * p - p.sum()) + + res = minimize(fun=obj, jac=jac, x0=D.mean(axis=1), method="L-BFGS-B") + if not res.success: + print("Global displacement gradient descent had an error") + displacement = res.x + + elif convergence_method == "lsqr_robust": + + if sparse_mask is not None: + I, J = np.nonzero(sparse_mask > 0) + elif pairwise_displacement_weight is not None: + I, J = pairwise_displacement_weight.nonzero() + else: + I, J = np.nonzero(np.ones_like(pairwise_displacement, dtype=bool)) + + nnz_ones = np.ones(I.shape[0], dtype=pairwise_displacement.dtype) + + if pairwise_displacement_weight is not None: + if isinstance(pairwise_displacement_weight, scipy.sparse.csr_matrix): + W = np.array(pairwise_displacement_weight[I, J]).T + else: + W = pairwise_displacement_weight[I, J][:, None] + else: + W = nnz_ones[:, None] + if isinstance(pairwise_displacement, scipy.sparse.csr_matrix): + V = np.array(pairwise_displacement[I, J])[0] + else: + V = pairwise_displacement[I, J] + M = csr_matrix((nnz_ones, (range(I.shape[0]), I)), shape=(I.shape[0], pairwise_displacement.shape[0])) + N = csr_matrix((nnz_ones, (range(I.shape[0]), J)), shape=(I.shape[0], pairwise_displacement.shape[0])) + A = M - N + idx = np.ones(A.shape[0], dtype=bool) + + # TODO: this is already soft_weights + xrange = trange if progress_bar else range + for i in xrange(lsqr_robust_n_iter): + p = lsqr(A[idx].multiply(W[idx]), V[idx] * W[idx][:, 0])[0] + idx = np.nonzero(np.abs(zscore(A @ p - V)) <= robust_regression_sigma) + displacement = p + + elif convergence_method == "lsmr": + import gc + from scipy import sparse + + D = pairwise_displacement + + # weighted problem + if pairwise_displacement_weight is None: + pairwise_displacement_weight = np.ones_like(D) + if sparse_mask is None: + sparse_mask = np.ones_like(D) + W = pairwise_displacement_weight * sparse_mask + if isinstance(W, scipy.sparse.csr_matrix): + W = W.astype(np.float32).toarray() + D = D.astype(np.float32).toarray() + + assert D.shape == W.shape + + # first dimension is the windows dim, which could be empty in rigid case + # we expand dims so that below we can consider only the nonrigid case + if D.ndim == 2: + W = W[None] + D = D[None] + assert D.ndim == W.ndim == 3 + B, T, T_ = D.shape + assert T == T_ + + # sparsify the problem + # we will make a list of temporal problems and then + # stack over the windows axis to finish. + # each matrix in coefficients will be (sparse_dim, T) + coefficients = [] + # each vector in targets will be (T,) + targets = [] + # we want to solve for a vector of shape BT, which we will reshape + # into a (B, T) matrix. + # after the loop below, we will stack a coefts matrix (sparse_dim, B, T) + # and a target vector of shape (B, T), both to be vectorized on last two axes, + # so that the target p is indexed by i = bT + t (block/window major). + + # calculate coefficients matrices and target vector + # this list stores boolean masks corresponding to whether or not each + # term comes from the prior or the likelihood. we can trim the likelihood terms, + # but not the prior terms, in the trimmed least squares (robust iters) iterations below. + cannot_trim = [] + for Wb, Db in zip(W, D): + # indices of active temporal pairs in this window + I, J = np.nonzero(Wb > 0) + n_sampled = I.size + + # construct Kroneckers and sparse objective in this window + pair_weights = np.ones(n_sampled) + if soft_weights: + pair_weights = Wb[I, J] + Mb = sparse.csr_matrix((pair_weights, (range(n_sampled), I)), shape=(n_sampled, T)) + Nb = sparse.csr_matrix((pair_weights, (range(n_sampled), J)), shape=(n_sampled, T)) + block_sparse_kron = Mb - Nb + block_disp_pairs = pair_weights * Db[I, J] + cannot_trim_block = np.ones_like(block_disp_pairs, dtype=bool) + + # add the temporal smoothness prior in this window + if temporal_prior: + temporal_diff_operator = sparse.diags( + ( + np.full(T - 1, -1, dtype=block_sparse_kron.dtype), + np.full(T - 1, 1, dtype=block_sparse_kron.dtype), + ), + offsets=(0, 1), + shape=(T - 1, T), + ) + block_sparse_kron = sparse.vstack( + (block_sparse_kron, temporal_diff_operator), + format="csr", + ) + block_disp_pairs = np.concatenate( + (block_disp_pairs, np.zeros(T - 1)), + ) + cannot_trim_block = np.concatenate( + (cannot_trim_block, np.zeros(T - 1, dtype=bool)), + ) + + coefficients.append(block_sparse_kron) + targets.append(block_disp_pairs) + cannot_trim.append(cannot_trim_block) + coefficients = sparse.block_diag(coefficients) + targets = np.concatenate(targets, axis=0) + cannot_trim = np.concatenate(cannot_trim, axis=0) + + # spatial smoothness prior: penalize difference of each block's + # displacement with the next. + # only if B > 1, and not in the last window. + # this is a (BT, BT) sparse matrix D such that: + # entry at (i, j) is: + # { 1 if i = j, i.e., i = j = bT + t for b = 0,...,B-2 + # { -1 if i = bT + t and j = (b+1)T + t for b = 0,...,B-2 + # { 0 otherwise. + # put more simply, the first (B-1)T diagonal entries are 1, + # and entries (i, j) such that i = j - T are -1. + if B > 1 and spatial_prior: + spatial_diff_operator = sparse.diags( + ( + np.ones((B - 1) * T, dtype=block_sparse_kron.dtype), + np.full((B - 1) * T, -1, dtype=block_sparse_kron.dtype), + ), + offsets=(0, T), + shape=((B - 1) * T, B * T), + ) + coefficients = sparse.vstack((coefficients, spatial_diff_operator)) + targets = np.concatenate((targets, np.zeros((B - 1) * T, dtype=targets.dtype))) + cannot_trim = np.concatenate((cannot_trim, np.zeros((B - 1) * T, dtype=bool))) + coefficients = coefficients.tocsr() + + # initialize at the column mean of pairwise displacements (in each window) + p0 = D.mean(axis=2).reshape(B * T) + + # use LSMR to solve the whole problem || targets - coefficients @ motion ||^2 + iters = range(max(1, lsqr_robust_n_iter)) + if progress_bar and lsqr_robust_n_iter > 1: + iters = tqdm(iters, desc="robust lsqr") + for it in iters: + # trim active set -- start with no trimming + idx = slice(None) + if it: + idx = np.flatnonzero( + cannot_trim | (np.abs(zscore(coefficients @ displacement - targets)) <= robust_regression_sigma) + ) + + # solve trimmed ols problem + displacement, *_ = sparse.linalg.lsmr(coefficients[idx], targets[idx], x0=p0) + + # warm start next iteration + p0 = displacement + # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + # TODO: check if this gets fixed in scipy + gc.collect() + + displacement = displacement.reshape(B, T).T + else: + raise ValueError( + f"Method {convergence_method} doesn't exist for compute_global_displacement" + f" possible values for 'convergence_method' are {_possible_convergence_method}" + ) + + return np.squeeze(displacement) + + +# normxcorr1d is now implemented in dredge +# we keep the old version here but this will be removed soon + +# def normxcorr1d( +# template, +# x, +# weights=None, +# centered=True, +# normalized=True, +# padding="same", +# conv_engine="torch", +# ): +# """normxcorr1d: Normalized cross-correlation, optionally weighted + +# The API is like torch's F.conv1d, except I have accidentally +# changed the position of input/weights -- template acts like weights, +# and x acts like input. + +# Returns the cross-correlation of `template` and `x` at spatial lags +# determined by `mode`. Useful for estimating the location of `template` +# within `x`. + +# This might not be the most efficient implementation -- ideas welcome. +# It uses a direct convolutional translation of the formula +# corr = (E[XY] - EX EY) / sqrt(var X * var Y) + +# This also supports weights! In that case, the usual adaptation of +# the above formula is made to the weighted case -- and all of the +# normalizations are done per block in the same way. + +# Parameters +# ---------- +# template : tensor, shape (num_templates, length) +# The reference template signal +# x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) +# The signal in which to find `template` +# weights : tensor, shape (length,) +# Will use weighted means, variances, covariances if supplied. +# centered : bool +# If true, means will be subtracted (per weighted patch). +# normalized : bool +# If true, normalize by the variance (per weighted patch). +# padding : str +# How far to look? if unset, we'll use half the length +# conv_engine : string, one of "torch", "numpy" +# What library to use for computing cross-correlations. +# If numpy, falls back to the scipy correlate function. + +# Returns +# ------- +# corr : tensor +# """ +# if conv_engine == "torch": +# assert HAVE_TORCH +# conv1d = F.conv1d +# npx = torch +# elif conv_engine == "numpy": +# conv1d = scipy_conv1d +# npx = np +# else: +# raise ValueError(f"Unknown conv_engine {conv_engine}. 'conv_engine' must be 'torch' or 'numpy'") + +# x = npx.atleast_2d(x) +# num_templates, length = template.shape +# num_inputs, length_ = template.shape +# assert length == length_ + +# # generalize over weighted / unweighted case +# device_kw = {} if conv_engine == "numpy" else dict(device=x.device) +# ones = npx.ones((1, 1, length), dtype=x.dtype, **device_kw) +# no_weights = weights is None +# if no_weights: +# weights = ones +# wt = template[:, None, :] +# else: +# assert weights.shape == (length,) +# weights = weights[None, None] +# wt = template[:, None, :] * weights + +# # conv1d valid rule: +# # (B,1,L),(O,1,L)->(B,O,L) + +# # compute expectations +# # how many points in each window? seems necessary to normalize +# # for numerical stability. +# N = conv1d(ones, weights, padding=padding) +# if centered: +# Et = conv1d(ones, wt, padding=padding) +# Et /= N +# Ex = conv1d(x[:, None, :], weights, padding=padding) +# Ex /= N + +# # compute (weighted) covariance +# # important: the formula E[XY] - EX EY is well-suited here, +# # because the means are naturally subtracted correctly +# # patch-wise. you couldn't pre-subtract them! +# cov = conv1d(x[:, None, :], wt, padding=padding) +# cov /= N +# if centered: +# cov -= Ex * Et + +# # compute variances for denominator, using var X = E[X^2] - (EX)^2 +# if normalized: +# var_template = conv1d(ones, wt * template[:, None, :], padding=padding) +# var_template /= N +# var_x = conv1d(npx.square(x)[:, None, :], weights, padding=padding) +# var_x /= N +# if centered: +# var_template -= npx.square(Et) +# var_x -= npx.square(Ex) + +# # now find the final normxcorr +# corr = cov # renaming for clarity +# if normalized: +# corr /= npx.sqrt(var_x) +# corr /= npx.sqrt(var_template) +# # get rid of NaNs in zero-variance areas +# corr[~npx.isfinite(corr)] = 0 + +# return corr diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py new file mode 100644 index 0000000000..a0dde6d52b --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -0,0 +1,1407 @@ +""" +Copy-paste and then refactoring of DREDge +https://github.com/evarol/dredge + +For historical reason, some function from the DREDge package where implemeneted +in spikeinterface in the motion_estimation.py before the DREDge package itself! + +Here a copy/paste (and small rewriting) of some functions from DREDge. + +The main entry for this function are still: + + * motion = estimate_motion((recording, ..., method='dredge_lfp') + * motion = estimate_motion((recording, ..., method='dredge_ap') < not Done yet + +but here the original functions from Charlie, Julien and Erdem have been ported for an +easier maintenance instead of making DREDge a dependency of spikeinterface. + +Some renaming has been done. Small details has been added. +But this code is very similar to the original code. +2 classes has been added : DredgeApRegistration and DredgeLfpRegistration +but the original function dredge_ap() and dredge_online_lfp() can be used directly. + +""" + +import warnings + +from tqdm.auto import trange +import numpy as np + +import gc + +from .motion_utils import ( + Motion, + get_spatial_windows, + get_window_domains, + scipy_conv1d, + make_2d_motion_histogram, + get_spatial_bin_edges, +) + + +# simple class wrapper to be compliant with estimate_motion +class DredgeApRegistration: + """ + Estimate motion from spikes times and depth. + + This the certified and official version of the dredge implementation. + + Method developed by the Paninski's group from Columbia university: + Charlie Windolf, Julien Boussard, Erdem Varol + + This method is quite similar to "decentralized" which was the previous implementation in spikeinterface. + + The reference is here https://www.biorxiv.org/content/10.1101/2023.10.24.563768v1 + + The original code were here : https://github.com/evarol/DREDge + But this code which use the same internal function is in line with the Motion object of spikeinterface contrary to the dredge repo. + + This code has been ported in spikeinterface (with simple copy/paste) by Samuel but main author is truely Charlie Windolf. + """ + + name = "dredge_ap" + need_peak_location = True + params_doc = """ + bin_um: float + Bin duration in second + bin_s : float + The size of the bins along depth in microns and along time in seconds. + The returned object's .displacement array will respect these bins. + Increasing these can lead to more stable estimates and faster runtimes + at the cost of spatial and/or temporal resolution. + max_disp_um : float + Maximum possible displacement in microns. If you can guess a number which is larger + than the largest displacement possible in your recording across a span of `time_horizon_s` + seconds, setting this value to that number can stabilize the result and speed up + the algorithm (since it can do less cross-correlating). + By default, this is set to win-scale_um / 4, or 112.5 microns. Which can be a bit + large! + time_horizon_s : float + "Time horizon" parameter, in seconds. Time bins separated by more seconds than this + will not be cross-correlated. So, if your data has nonstationarities or changes which + could lead to bad cross-correlations at some timescale, it can help to input that + value here. If this is too small, it can make the motion estimation unstable. + mincorr : float, between 0 and 1 + Correlation threshold. Pairs of frames whose maximal cross correlation value is smaller + than this threshold will be ignored when solving for the global displacement estimate. + thomas_kw, xcorr_kw, raster_kw, weights_kw + These dictionaries allow setting parameters for fine control over the registration + device : str or torch.device + What torch device to run on? E.g., "cpu" or "cuda" or "cuda:1". + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ): + + outs = dredge_ap( + recording, + peaks, + peak_locations, + direction=direction, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + extra_outputs=(extra is not None), + progress_bar=progress_bar, + **method_kwargs, + ) + + if extra is not None: + motion, extra_ = outs + extra.update(extra_) + else: + motion = outs + return motion + + +# @TODO : Charlie I started very small refactoring, I let you continue +def dredge_ap( + recording, + peaks, + peak_locations, + direction="y", + rigid=False, + # nonrigid window construction arguments + win_shape="gaussian", + win_step_um=400, + win_scale_um=450, + win_margin_um=None, + bin_um=1.0, + bin_s=1.0, + max_disp_um=None, + time_horizon_s=1000.0, + mincorr=0.1, + # weights arguments + do_window_weights=True, + weights_threshold_low=0.2, + weights_threshold_high=0.2, + mincorr_percentile=None, + mincorr_percentile_nneighbs=None, + # raster arguments + amp_scale_fn=None, ## @Charlie this one is not used anymore + post_transform=np.log1p, ###@this one is directly transimited to weight_correlation_matrix() and so get_wieiith() + histogram_depth_smooth_um=1, + histogram_time_smooth_s=1, + avg_in_bin=False, + # low-level keyword args + thomas_kw=None, + xcorr_kw=None, + # misc + device=None, + progress_bar=True, + extra_outputs=False, + precomputed_D_C_maxdisp=None, +): + """Estimate motion from spikes + + Spikes located at depths specified in `depths` along the probe, occurring at times in + seconds specified in `times` with amplitudes `amps` are used to create a 2d image of + the spiking activity. This image is cross-correlated with itself to produce a displacement + matrix (or several, one for each nonrigid window). This matrix is used to solve for a + motion estimate. + + Arguments + --------- + recording: BaseRecording + The recording extractor + peaks: numpy array + Peak vector (complex dtype). + Needed for decentralized and iterative_template methods. + peak_locations: numpy array + Complex dtype with "x", "y", "z" fields + Needed for decentralized and iterative_template methods. + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + rigid : bool, default=False + If True, ignore the nonrigid window args (win_shape, win_step_um, win_scale_um, + win_margin_um) and do rigid registration (equivalent to one flat window, which + is how it's implemented). + win_shape : str, default="gaussian" + Nonrigid window shape + win_step_um : float + Spacing between nonrigid window centers in microns + win_scale_um : float + Controls the width of nonrigid windows centers + win_margin_um : float + Distance of nonrigid windows centers from the probe boundary (-1000 means there will + be no window center within 1000um of the edge of the probe) + {} + + Returns + ------- + motion : Motion + The motion object + extra : dict + This has extra info about what happened during registration, including the nonrigid + windows if one wants to visualize them. Set `extra_outputs` to also save displacement + and correlation matrices. + """ + + dim = ["x", "y", "z"].index(direction) + # @charlie: I removed amps/depths_um/times_s from the signature + # preaks and peak_locations are more SI compatible + # the way to get then + amps = peak_amplitudes = peaks["amplitude"] + depths_um = peak_depths = peak_locations[direction] + times_s = peak_times = recording.sample_index_to_time(peaks["sample_index"]) + + thomas_kw = thomas_kw if thomas_kw is not None else {} + xcorr_kw = xcorr_kw if xcorr_kw is not None else {} + if time_horizon_s: + xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s) + + # TODO @charlie I think this is a bad to have the dict which is transported to every function + # this should be used only in histogram function but not in weight_correlation_matrix() + # only important kwargs should be explicitly reported + # raster_kw = dict( + # amp_scale_fn=amp_scale_fn, + # post_transform=post_transform, + # histogram_depth_smooth_um=histogram_depth_smooth_um, + # histogram_time_smooth_s=histogram_time_smooth_s, + # bin_s=bin_s, + # bin_um=bin_um, + # avg_in_bin=avg_in_bin, + # return_counts=count_masked_correlation, + # count_bins=count_bins, + # count_bin_min=count_bin_min, + # ) + + weights_kw = dict( + mincorr=mincorr, + time_horizon_s=time_horizon_s, + do_window_weights=do_window_weights, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + ) + + # this will store return values other than the MotionEstimate + extra = {} + + # TODO charlie I switch this to make_2d_motion_histogram + # but we need to add all options from the original spike_raster() + # but I think this is OK + # raster_res = spike_raster( + # amps, + # depths_um, + # times_s, + # **raster_kw, + # ) + # if count_masked_correlation: + # raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res + # else: + # raster, spatial_bin_edges_um, time_bin_edges_s = raster_res + + motion_histogram, time_bin_edges_s, spatial_bin_edges_um = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=True, + avg_in_bin=avg_in_bin, + direction=direction, + bin_s=bin_s, + bin_um=bin_um, + hist_margin_um=0.0, # @charlie maybe we should expose this and set +20. for instance + spatial_bin_edges=None, + depth_smooth_um=histogram_depth_smooth_um, + time_smooth_s=histogram_time_smooth_s, + ) + raster = motion_histogram.T + + # TODO charlie : put the log for hitstogram + + # TODO @charlie you should check that we are doing the same thing + # windows, window_centers = get_spatial_windows( + # np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], + # win_step_um, + # win_scale_um, + # spatial_bin_edges=spatial_bin_edges_um, + # margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um, + # win_shape=win_shape, + # zero_threshold=1e-5, + # rigid=rigid, + # ) + + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) + + windows, window_centers = get_spatial_windows( + contact_depths, + spatial_bin_centers, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + zero_threshold=1e-5, + ) + + # TODO charlie : the count has disapeared + # if extra_outputs and count_masked_correlation: + # extra["counts"] = counts + + # cross-correlate to get D and C + if precomputed_D_C_maxdisp is None: + Ds, Cs, max_disp_um = xcorr_windows( + raster, + windows, + spatial_bin_edges_um, + win_scale_um, + rigid=rigid, + bin_um=bin_um, + max_disp_um=max_disp_um, + progress_bar=progress_bar, + device=device, + # TODO charlie : put back the count for the mask + # masks=(counts > 0) if count_masked_correlation else None, + **xcorr_kw, + ) + else: + Ds, Cs, max_disp_um = precomputed_D_C_maxdisp + + # turn Cs into weights + Us, wextra = weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + spatial_bin_edges_um, + time_bin_edges_s, + # raster_kw, #@charlie this is removed + post_transform=post_transform, # @charlie this isnew + lambda_t=thomas_kw.get("lambda_t", DEFAULT_LAMBDA_T), + eps=thomas_kw.get("eps", DEFAULT_EPS), + progress_bar=progress_bar, + in_place=not extra_outputs, + **weights_kw, + ) + extra.update({k: wextra[k] for k in wextra if k not in ("S", "U")}) + if extra_outputs: + extra.update({k: wextra[k] for k in wextra if k in ("S", "U")}) + del wextra + if extra_outputs: + extra["D"] = Ds + extra["C"] = Cs + del Cs + + # @charlie : is this needed ? + gc.collect() + + # solve for P + # now we can do our tridiag solve + displacement, textra = thomas_solve(Ds, Us, progress_bar=progress_bar, **thomas_kw) + if extra_outputs: + extra.update(textra) + del textra + + if extra_outputs: + extra["windows"] = windows + extra["window_centers"] = window_centers + extra["max_disp_um"] = max_disp_um + + time_bin_centers = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + motion = Motion([displacement.T], [time_bin_centers], window_centers, direction=direction) + + if extra_outputs: + return motion, extra + else: + return motion + + +dredge_ap.__doc__ = dredge_ap.__doc__.format(DredgeApRegistration.params_doc) + + +# simple class wrapper to be compliant with estimate_motion +class DredgeLfpRegistration: + """ + Estimate motion from LFP recording. + + This the certified and official version of the dredge implementation. + + Method developed by the Paninski's group from Columbia university: + Charlie Windolf, Julien Boussard, Erdem Varol + + The reference is here https://www.biorxiv.org/content/10.1101/2023.10.24.563768v1 + """ + + name = "dredge_lfp" + need_peak_location = False + params_doc = """ + lfp_recording : spikeinterface BaseRecording object + Preprocessed LFP recording. The temporal resolution of this recording will + be the target resolution of the registration, so definitely use SpikeInterface + to resample your recording to, say, 250Hz (or a value you like) rather than + estimating motion at the original frequency (which may be high). + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + rigid : boolean, optional + If True, window-related arguments are ignored and we do rigid registration + win_shape, win_step_um, win_scale_um, win_margin_um : float + Nonrigid window-related arguments + The depth domain will be broken up into windows with shape controlled by win_shape, + spaced by win_step_um at a margin of win_margin_um from the boundary, and with + width controlled by win_scale_um. + chunk_len_s : float + Length of chunks (in seconds) that the recording is broken into for online + registration. The computational speed of the method is a function of the + number of samples this corresponds to, and things can get slow if it is + set high enough that the number of samples per chunk is bigger than ~10,000. + But, it can't be set too low or the algorithm doesn't have enough data + to work with. The default is set assuming sampling rate of 250Hz, leading + to 2500 samples per chunk. + time_horizon_s : float + Time-bins farther apart than this value in seconds will not be cross-correlated. + Set this to at least `chunk_len_s`. + max_disp_um : number, optional + This is the ceiling on the possible displacement estimates. It should be + set to a number which is larger than the allowed displacement in a single + chunk. Setting it as small as possible (while following that rule) can speed + things up and improve the result by making it impossible to estimate motion + which is too big. + mincorr : float in [0,1] + Minimum correlation between pairs of frames such that they will be included + in the optimization of the displacement estimates. + mincorr_percentile, mincorr_percentile_nneighbs + If mincorr_percentile is set to a number in [0, 100], then mincorr will be replaced + by this percentile of the correlations of neighbors within mincorr_percentile_nneighbs + time bins of each other. + device : string or torch.device + Controls torch device + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ): + # Note peaks and peak_locations are not used and can be None + + outs = dredge_online_lfp( + recording, + direction=direction, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + extra_outputs=(extra is not None), + progress_bar=progress_bar, + **method_kwargs, + ) + + if extra is not None: + motion, extra_ = outs + extra.update(extra_) + else: + motion = outs + return motion + + +def dredge_online_lfp( + lfp_recording, + direction="y", + # nonrigid window construction arguments + rigid=True, + win_shape="gaussian", + win_step_um=800, + win_scale_um=850, + win_margin_um=None, + chunk_len_s=10.0, + max_disp_um=500, + time_horizon_s=None, + # weighting arguments + mincorr=0.8, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + soft=False, + # low-level arguments + thomas_kw=None, + xcorr_kw=None, + # misc + extra_outputs=False, + device=None, + progress_bar=True, +): + """Online registration of a preprocessed LFP recording + + Arguments + --------- + {} + + Returns + ------- + motion : Motion + A motion object. + extra : dict + Dict containing extra info for debugging + """ + dim = ["x", "y", "z"].index(direction) + # contact pos is the only on the direction + contact_depths = lfp_recording.get_channel_locations()[:, dim] + + fs = lfp_recording.get_sampling_frequency() + T_total = lfp_recording.get_num_samples() + T_chunk = min(int(np.floor(fs * chunk_len_s)), T_total) + + # kwarg defaults and handling + # need lfp-specific defaults + xcorr_kw = xcorr_kw if xcorr_kw is not None else {} + thomas_kw = thomas_kw if thomas_kw is not None else {} + full_xcorr_kw = dict( + rigid=rigid, + bin_um=np.median(np.diff(contact_depths)), + max_disp_um=max_disp_um, + progress_bar=False, + device=device, + **xcorr_kw, + ) + threshold_kw = dict( + mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, + in_place=True, + soft=soft, + # time_horizon_s=weights_kw["time_horizon_s"], # max_dt not implemented for lfp at this point + time_horizon_s=time_horizon_s, + bin_s=1 / fs, # only relevant for time_horizon_s + ) + + # here we check that contact positons are unique on the direction + if contact_depths.size != np.unique(contact_depths).size: + raise ValueError( + f"estimate motion with 'dredge_lfp' need channel_positions to be unique in the direction='{direction}'" + ) + if np.any(np.diff(contact_depths) < 0): + raise ValueError( + f"estimate motion with 'dredge_lfp' need channel_positions to be ordered direction='{direction}'" + "please use spikeinterface.preprocessing.depth_order(recording)" + ) + + # Important detail : in LFP bin center are contact position in the direction + spatial_bin_centers = contact_depths + + windows, window_centers = get_spatial_windows( + contact_depths=contact_depths, + spatial_bin_centers=spatial_bin_centers, + rigid=rigid, + win_margin_um=win_margin_um, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_shape=win_shape, + zero_threshold=1e-5, + ) + + B = len(windows) + + if extra_outputs: + extra = dict(window_centers=window_centers, windows=windows) + + # -- allocate output and initialize first chunk + P_online = np.empty((B, T_total), dtype=np.float32) + # below, t0 is start of prev chunk, t1 start of cur chunk, t2 end of cur + t0, t1 = 0, T_chunk + traces0 = lfp_recording.get_traces(start_frame=t0, end_frame=t1) + Ds0, Cs0, max_disp_um = xcorr_windows(traces0.T, windows, contact_depths, win_scale_um, **full_xcorr_kw) + full_xcorr_kw["max_disp_um"] = max_disp_um + Ss0, mincorr0 = threshold_correlation_matrix( + Cs0, + mincorr=mincorr, + mincorr_percentile=mincorr_percentile, + **threshold_kw, + ) + if extra_outputs: + extra["D"] = [Ds0] + extra["C"] = [Cs0] + extra["S"] = [Ss0] + extra["D01"] = [] + extra["C01"] = [] + extra["S01"] = [] + extra["mincorrs"] = [mincorr0] + extra["max_disp_um"] = max_disp_um + + P_online[:, t0:t1], _ = thomas_solve(Ds0, Ss0, **thomas_kw) + + # -- loop through chunks + chunk_starts = range(T_chunk, T_total, T_chunk) + if progress_bar: + chunk_starts = trange( + T_chunk, + T_total, + T_chunk, + desc=f"Online chunks [{chunk_len_s}s each]", + ) + for t1 in chunk_starts: + t2 = min(T_total, t1 + T_chunk) + traces1 = lfp_recording.get_traces(start_frame=t1, end_frame=t2) + + # cross-correlations between prev/cur chunks + # these are T1, T0 shaped + Ds10, Cs10, _ = xcorr_windows( + traces1.T, + windows, + contact_depths, + win_scale_um, + raster_b=traces0.T, + **full_xcorr_kw, + ) + + # cross-correlation in current chunk + Ds1, Cs1, _ = xcorr_windows(traces1.T, windows, contact_depths, win_scale_um, **full_xcorr_kw) + Ss1, mincorr1 = threshold_correlation_matrix( + Cs1, + mincorr_percentile=mincorr_percentile, + mincorr=mincorr, + **threshold_kw, + ) + Ss10, _ = threshold_correlation_matrix(Cs10, mincorr=mincorr1, t_offset_bins=T_chunk, **threshold_kw) + + if extra_outputs: + extra["mincorrs"].append(mincorr1) + extra["D"].append(Ds1) + extra["C"].append(Cs1) + extra["S"].append(Ss1) + extra["D01"].append(Ds10) + extra["C01"].append(Cs10) + extra["S01"].append(Ss10) + + # solve online problem + P_online[:, t1:t2], _ = thomas_solve( + Ds1, + Ss1, + P_prev=P_online[:, t0:t1], + Ds_curprev=Ds10, + Us_curprev=Ss10, + Ds_prevcur=-Ds10.transpose(0, 2, 1), + Us_prevcur=Ss10.transpose(0, 2, 1), + **thomas_kw, + ) + + # update loop vars + t0, t1 = t1, t2 + traces0 = traces1 + + motion = Motion([P_online.T], [lfp_recording.get_times(0)], window_centers, direction=direction) + + if extra_outputs: + return motion, extra + else: + return motion + + +dredge_online_lfp.__doc__ = dredge_online_lfp.__doc__.format(DredgeLfpRegistration.params_doc) + + +# -- functions from dredgelib (zone forbiden for sam) + +DEFAULT_LAMBDA_T = 1.0 +DEFAULT_EPS = 1e-3 + +# -- linear algebra, Newton method solver, block tridiagonal (Thomas) solver + + +def laplacian(n, wink=True, eps=DEFAULT_EPS, lambd=1.0, ridge_mask=None): + """Construct a discrete Laplacian operator (plus eps*identity).""" + lap = np.zeros((n, n)) + if ridge_mask is None: + diag = lambd + eps + else: + diag = lambd + eps * ridge_mask + np.fill_diagonal(lap, diag) + if wink: + lap[0, 0] -= 0.5 * lambd + lap[-1, -1] -= 0.5 * lambd + # fill diagonal using a for loop for space reasons when this is large + for i in range(n - 1): + lap[i, i + 1] -= 0.5 * lambd + lap[i + 1, i] -= 0.5 * lambd + return lap + + +def neg_hessian_likelihood_term(Ub, Ub_prevcur=None, Ub_curprev=None): + """Newton step coefficients + + The negative Hessian of the non-regularized cost function inside a nonrigid block. + Together with the term arising from the regularization, this constructs the + coefficients matrix in our linear problem. + """ + negHUb = -Ub - Ub.T + diagonal_terms = np.diagonal(negHUb) + Ub.sum(1) + Ub.sum(0) + if Ub_prevcur is None: + np.fill_diagonal(negHUb, diagonal_terms) + else: + diagonal_terms += Ub_prevcur.sum(0) + Ub_curprev.sum(1) + np.fill_diagonal(negHUb, diagonal_terms) + return negHUb + + +def newton_rhs( + Db, + Ub, + Pb_prev=None, + Db_prevcur=None, + Ub_prevcur=None, + Db_curprev=None, + Ub_curprev=None, +): + """Newton step right hand side + + The gradient at P=0 of the cost function, which is the right hand side of Newton's method. + """ + UDb = Ub * Db + grad_at_0 = UDb.sum(1) - UDb.sum(0) + + # batch case + if Pb_prev is None: + return grad_at_0 + + # online case + align_term = (Ub_prevcur.T + Ub_curprev) @ Pb_prev + rhs = align_term + grad_at_0 + (Ub_curprev * Db_curprev).sum(1) - (Ub_prevcur * Db_prevcur).sum(0) + + return rhs + + +def newton_solve_rigid( + D, + U, + Sigma0inv, + Pb_prev=None, + Db_prevcur=None, + Ub_prevcur=None, + Db_curprev=None, + Ub_curprev=None, +): + """Solve the rigid Newton step + + D is TxT displacement, U is TxT subsampling or soft weights matrix. + """ + from scipy.linalg import solve, lstsq + + negHU = neg_hessian_likelihood_term( + U, + Ub_prevcur=Ub_prevcur, + Ub_curprev=Ub_curprev, + ) + targ = newton_rhs( + D, + U, + Pb_prev=Pb_prev, + Db_prevcur=Db_prevcur, + Ub_prevcur=Ub_prevcur, + Db_curprev=Db_curprev, + Ub_curprev=Ub_curprev, + ) + try: + p = solve(Sigma0inv + negHU, targ, assume_a="pos") + except np.linalg.LinAlgError: + warnings.warn("Singular problem, using least squares.") + p, *_ = lstsq(Sigma0inv + negHU, targ) + return p, negHU + + +def thomas_solve( + Ds, + Us, + lambda_t=DEFAULT_LAMBDA_T, + lambda_s=1.0, + eps=DEFAULT_EPS, + P_prev=None, + Ds_prevcur=None, + Us_prevcur=None, + Ds_curprev=None, + Us_curprev=None, + progress_bar=False, + bandwidth=None, +): + """Block tridiagonal algorithm, special cased to our setting + + This code solves for the displacement estimates across the nonrigid windows, + given blockwise, pairwise (BxTxT) displacement and weights arrays `Ds` and `Us`. + + If `lambda_t>0`, a temporal prior is applied to "fill the gaps", effectively + interpolating through time to avoid artifacts in low-signal areas. Setting this + to 0 can lead to numerical warnings and should be done with care. + + If `lambda_s>0`, a spatial prior is applied. This can help fill gaps more + meaningfully in the nonrigid case, using information from the neighboring nonrigid + windows to inform the estimate in an untrusted region of a given window. + + If arguments `P_prev,Ds_prevcur,Us_prevcur` are supplied, this code handles the + online case. The return value will be the new chunk's displacement estimate, + solving the online registration problem. + """ + from scipy.linalg import solve + + Ds = np.asarray(Ds, dtype=np.float64) + Us = np.asarray(Us, dtype=np.float64) + online = P_prev is not None + online_kw_rhs = online_kw_hess = lambda b: {} + if online: + assert Ds_prevcur is not None + assert Us_prevcur is not None + online_kw_rhs = lambda b: dict( # noqa + Pb_prev=P_prev[b].astype(np.float64, copy=False), + Db_prevcur=Ds_prevcur[b].astype(np.float64, copy=False), + Ub_prevcur=Us_prevcur[b].astype(np.float64, copy=False), + Db_curprev=Ds_curprev[b].astype(np.float64, copy=False), + Ub_curprev=Us_curprev[b].astype(np.float64, copy=False), + ) + online_kw_hess = lambda b: dict( # noqa + Ub_prevcur=Us_prevcur[b].astype(np.float64, copy=False), + Ub_curprev=Us_curprev[b].astype(np.float64, copy=False), + ) + + B, T, T_ = Ds.shape + assert T == T_ + assert Us.shape == Ds.shape + + # figure out which temporal bins are included in the problem + # these are used to figure out where epsilon can be added + # for numerical stability without changing the solution + had_weights = (Us > 0).any(axis=2) + had_weights[~had_weights.any(axis=1)] = 1 + + # temporal prior matrix + L_t = [laplacian(T, eps=eps, lambd=lambda_t, ridge_mask=w) for w in had_weights] + extra = dict(L_t=L_t) + + # just solve independent problems when there's no spatial regularization + # not that there's much overhead to the backward pass etc but might as well + if B == 1 or lambda_s == 0: + P = np.zeros((B, T)) + extra["HU"] = np.zeros((B, T, T)) + for b in range(B): + P[b], extra["HU"][b] = newton_solve_rigid(Ds[b], Us[b], L_t[b], **online_kw_rhs(b)) + return P, extra + + # spatial prior is a sparse, block tridiagonal kronecker product + # the first and last diagonal blocks are + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s / 2, ridge_mask=had_weights[0]) + # and the off-diagonal blocks are + Lambda_s_offdiag = laplacian(T, eps=0, lambd=-lambda_s / 2) + + # initialize block-LU stuff and forward variable + alpha_hat_b = L_t[0] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[0], **online_kw_hess(0)) + targets = np.c_[Lambda_s_offdiag, newton_rhs(Us[0], Ds[0], **online_kw_rhs(0))] + res = solve(alpha_hat_b, targets, assume_a="pos") + assert res.shape == (T, T + 1) + gamma_hats = [res[:, :T]] + ys = [res[:, T]] + + # forward pass + for b in trange(1, B, desc="Solve") if progress_bar else range(1, B): + if b < B - 1: + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s, ridge_mask=had_weights[b]) + else: + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s / 2, ridge_mask=had_weights[b]) + + Ab = L_t[b] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[b], **online_kw_hess(b)) + alpha_hat_b = Ab - Lambda_s_offdiag @ gamma_hats[b - 1] + targets[:, T] = newton_rhs(Us[b], Ds[b], **online_kw_rhs(b)) + targets[:, T] -= Lambda_s_offdiag @ ys[b - 1] + res = solve(alpha_hat_b, targets) + assert res.shape == (T, T + 1) + gamma_hats.append(res[:, :T]) + ys.append(res[:, T]) + + # back substitution + xs = [None] * B + xs[-1] = ys[-1] + for b in range(B - 2, -1, -1): + xs[b] = ys[b] - gamma_hats[b] @ xs[b + 1] + + # un-vectorize + P = np.concatenate(xs).reshape(B, T) + + return P, extra + + +def threshold_correlation_matrix( + Cs, + mincorr=0.0, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + time_horizon_s=0, + in_place=False, + bin_s=1, + t_offset_bins=None, + T=None, + soft=True, +): + if mincorr_percentile is not None: + diags = [np.diagonal(Cs, offset=j, axis1=1, axis2=2).ravel() for j in range(1, mincorr_percentile_nneighbs)] + mincorr = np.percentile( + np.concatenate(diags), + mincorr_percentile, + ) + + # need abs to avoid -0.0s which cause numerical issues + if in_place: + Ss = Cs + if soft: + Ss[Ss < mincorr] = 0 + else: + Ss = (Ss >= mincorr).astype(Cs.dtype) + np.square(Ss, out=Ss) + else: + if soft: + Ss = np.square((Cs >= mincorr) * Cs) + else: + Ss = (Cs >= mincorr).astype(Cs.dtype) + if time_horizon_s is not None and time_horizon_s > 0 and T is not None and time_horizon_s < T: + tt0 = bin_s * np.arange(T) + tt1 = tt0 + if t_offset_bins: + tt1 = tt0 + t_offset_bins + dt = tt1[:, None] - tt0[None, :] + mask = (np.abs(dt) <= time_horizon_s).astype(Ss.dtype) + Ss *= mask[None] + return Ss, mincorr + + +def xcorr_windows( + raster_a, + windows, + spatial_bin_edges_um, + win_scale_um, + raster_b=None, + rigid=False, + bin_um=1, + max_disp_um=None, + max_dt_bins=None, + progress_bar=True, + centered=True, + normalized=True, + masks=None, + device=None, +): + """Main computational function + + Compute pairwise (time x time) maximum cross-correlation and displacement + matrices in each nonrigid window. + """ + import torch + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if max_disp_um is None: + if rigid: + max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + else: + max_disp_um = int(win_scale_um // 4) + + max_disp_bins = int(max_disp_um // bin_um) + slices = get_window_domains(windows) + B, D = windows.shape + D_, T0 = raster_a.shape + + assert D == D_ + + # torch versions on device + windows_ = torch.as_tensor(windows, dtype=torch.float, device=device) + raster_a_ = torch.as_tensor(raster_a, dtype=torch.float, device=device) + if raster_b is not None: + assert raster_b.shape[0] == D + T1 = raster_b.shape[1] + raster_b_ = torch.as_tensor(raster_b, dtype=torch.float, device=device) + else: + T1 = T0 + raster_b_ = raster_a_ + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float, device=device) + + # estimate each window's displacement + Ds = np.zeros((B, T0, T1), dtype=np.float32) + Cs = np.zeros((B, T0, T1), dtype=np.float32) + block_iter = trange(B, desc="Cross correlation") if progress_bar else range(B) + for b in block_iter: + window = windows_[b] + + # we search for the template (windowed part of raster a) + # within a larger-than-the-window neighborhood in raster b + targ_low = slices[b].start - max_disp_bins + b_low = max(0, targ_low) + targ_high = slices[b].stop + max_disp_bins + b_high = min(D, targ_high) + padding = max(b_low - targ_low, targ_high - b_high) + + # arithmetic to compute the lags in um corresponding to + # corr argmaxes + n_left = padding + slices[b].start - b_low + n_right = padding + b_high - slices[b].stop + poss_disp = -np.arange(-n_left, n_right + 1) * bin_um + + Ds[b], Cs[b] = calc_corr_decent_pair( + raster_a_[slices[b]], + raster_b_[b_low:b_high], + weights=window[slices[b]], + masks=None if masks is None else masks[slices[b]], + xmasks=None if masks is None else masks[b_low:b_high], + disp=padding, + possible_displacement=poss_disp, + device=device, + centered=centered, + normalized=normalized, + max_dt_bins=max_dt_bins, + ) + + return Ds, Cs, max_disp_um + + +def calc_corr_decent_pair( + raster_a, + raster_b, + weights=None, + masks=None, + xmasks=None, + disp=None, + batch_size=512, + normalized=True, + centered=True, + possible_displacement=None, + max_dt_bins=None, + device=None, +): + """Weighted pairwise cross-correlation + + Calculate TxT normalized xcorr and best displacement matrices + Given a DxT raster, this computes normalized cross correlations for + all pairs of time bins at offsets in the range [-disp, disp], by + increments of step_size. Then it finds the best one and its + corresponding displacement, resulting in two TxT matrices: one for + the normxcorrs at the best displacement, and the matrix of the best + displacements. + + Arguments + --------- + raster : DxT array + batch_size : int + How many raster rows to xcorr against the whole raster + at once. + step_size : int + Displacement increment. Not implemented yet but easy to do. + disp : int + Maximum displacement + device : torch device + Returns: D, C: TxT arrays + """ + import torch + + D, Ta = raster_a.shape + D_, Tb = raster_b.shape + + # sensible default: at most half the domain. + if disp is None: + disp == D // 2 + + # range of displacements + if D == D_: + if possible_displacement is None: + possible_displacement = np.arange(-disp, disp + 1) + else: + assert possible_displacement is not None + assert disp is not None + + # pick torch device if unset + if device is None: + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # process rasters into the tensors we need for conv2ds below + # convert to TxD device floats + raster_a = torch.as_tensor(raster_a.T, dtype=torch.float32, device=device) + # normalize over depth for normalized (uncentered) xcorrs + raster_b = torch.as_tensor(raster_b.T, dtype=torch.float32, device=device) + + D = np.zeros((Ta, Tb), dtype=np.float32) + C = np.zeros((Ta, Tb), dtype=np.float32) + for i in range(0, Ta, batch_size): + for j in range(0, Tb, batch_size): + dt_bins = min(abs(i - j), abs(i + batch_size - j), abs(i - j - batch_size)) + if max_dt_bins and dt_bins > max_dt_bins: + continue + weights_ = weights + if masks is not None: + weights_ = masks.T[i : i + batch_size] * weights + corr = normxcorr1d( + raster_a[i : i + batch_size], + raster_b[j : j + batch_size], + weights=weights_, + xmasks=None if xmasks is None else xmasks.T[j : j + batch_size], + padding=disp, + normalized=normalized, + centered=centered, + ) + max_corr, best_disp_inds = torch.max(corr, dim=2) + best_disp = possible_displacement[best_disp_inds.cpu()] + D[i : i + batch_size, j : j + batch_size] = best_disp.T + C[i : i + batch_size, j : j + batch_size] = max_corr.cpu().T + + return D, C + + +def normxcorr1d( + template, + x, + weights=None, + xmasks=None, + centered=True, + normalized=True, + padding="same", + conv_engine="torch", +): + """ + normxcorr1d: Normalized cross-correlation, optionally weighted + + The API is like torch's F.conv1d, except I have accidentally + changed the position of input/weights -- template acts like weights, + and x acts like input. + + Returns the cross-correlation of `template` and `x` at spatial lags + determined by `mode`. Useful for estimating the location of `template` + within `x`. + + This might not be the most efficient implementation -- ideas welcome. + It uses a direct convolutional translation of the formula + corr = (E[XY] - EX EY) / sqrt(var X * var Y) + + This also supports weights! In that case, the usual adaptation of + the above formula is made to the weighted case -- and all of the + normalizations are done per block in the same way. + + Parameters + ---------- + template : tensor, shape (num_templates, length) + The reference template signal + x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) + The signal in which to find `template` + weights : tensor, shape (length,) + Will use weighted means, variances, covariances if supplied. + centered : bool + If true, means will be subtracted (per weighted patch). + normalized : bool + If true, normalize by the variance (per weighted patch). + padding : int, optional + How far to look? if unset, we'll use half the length + conv_engine : "torch" | "numpy" + What library to use for computing cross-correlations. + If numpy, falls back to the scipy correlate function. + + Returns + ------- + corr : tensor + """ + + if conv_engine == "torch": + import torch + import torch.nn.functional as F + + conv1d = F.conv1d + npx = torch + elif conv_engine == "numpy": + conv1d = scipy_conv1d + npx = np + else: + raise ValueError(f"Unknown conv_engine {conv_engine}") + + x = npx.atleast_2d(x) + num_templates, lengtht = template.shape + num_inputs, lengthx = x.shape + + # generalize over weighted / unweighted case + device_kw = {} if conv_engine == "numpy" else dict(device=x.device) + if xmasks is None: + onesx = npx.ones((1, 1, lengthx), dtype=x.dtype, **device_kw) + wx = x[:, None, :] + else: + assert xmasks.shape == x.shape + onesx = xmasks[:, None, :] + wx = x[:, None, :] * onesx + no_weights = weights is None + if no_weights: + weights = npx.ones((1, 1, lengtht), dtype=x.dtype, **device_kw) + wt = template[:, None, :] + else: + if weights.shape == (lengtht,): + weights = weights[None, None] + elif weights.shape == (num_templates, lengtht): + weights = weights[:, None, :] + else: + assert False + wt = template[:, None, :] * weights + x = x[:, None, :] + template = template[:, None, :] + + # conv1d valid rule: + # (B,1,L),(O,1,L)->(B,O,L) + # below, we always put x on the LHS, templates on the RHS, so this reads + # (num_inputs, 1, lengthx), (num_templates, 1, lengtht) -> (num_inputs, num_templates, length_out) + + # compute expectations + # how many points in each window? seems necessary to normalize + # for numerical stability. + Nx = conv1d(onesx, weights, padding=padding) # 1,nt,l + empty = Nx == 0 + Nx[empty] = 1 + if centered: + Et = conv1d(onesx, wt, padding=padding) # 1,nt,l + Et /= Nx + Ex = conv1d(wx, weights, padding=padding) # nx,nt,l + Ex /= Nx + + # compute (weighted) covariance + # important: the formula E[XY] - EX EY is well-suited here, + # because the means are naturally subtracted correctly + # patch-wise. you couldn't pre-subtract them! + cov = conv1d(wx, wt, padding=padding) + cov /= Nx + if centered: + cov -= Ex * Et + + # compute variances for denominator, using var X = E[X^2] - (EX)^2 + if normalized: + var_template = conv1d(onesx, wt * template, padding=padding) + var_template /= Nx + var_x = conv1d(wx * x, weights, padding=padding) + var_x /= Nx + if centered: + var_template -= npx.square(Et) + var_x -= npx.square(Ex) + + # fill in zeros to avoid problems when dividing + var_template[var_template <= 0] = 1 + var_x[var_x <= 0] = 1 + + # now find the final normxcorr + corr = cov # renaming for clarity + if normalized: + corr[npx.broadcast_to(empty, corr.shape)] = 0 + corr /= npx.sqrt(var_x) + corr /= npx.sqrt(var_template) + + return corr + + +def get_weights( + Ds, + Ss, + Sigma0inv_t, + windows, + raster, + dbe, + tbe, + # @charlie raster_kw is removed in favor of post_transform only is this OK ??? + # raster_kw, + post_transform=np.log1p, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=False, +): + """Compute per-time-bin weighting for each nonrigid window""" + # determine window-weighted raster "heat" in each nonrigid window + # as a function of time + assert windows.shape[1] == dbe.size - 1 + weights = [] + p_inds = [] + for b in range((len(Ds))): + ilow, ihigh = np.flatnonzero(windows[b])[[0, -1]] + ihigh += 1 + window_sliced = windows[b, ilow:ihigh] + weights.append(window_sliced @ raster[ilow:ihigh]) + weights_orig = np.array(weights) + + # scale_fn = raster_kw["post_transform"] or raster_kw["amp_scale_fn"] + scale_fn = post_transform + if isinstance(weights_threshold_low, tuple): + nspikes_threshold_low, amp_threshold_low = weights_threshold_low + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_low = scale_fn(amp_threshold_low) * windows @ (nspikes_threshold_low * unif) + weights_threshold_low = weights_threshold_low[:, None] + if isinstance(weights_threshold_high, tuple): + nspikes_threshold_high, amp_threshold_high = weights_threshold_high + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_high = scale_fn(amp_threshold_high) * windows @ (nspikes_threshold_high * unif) + weights_threshold_high = weights_threshold_high[:, None] + weights_thresh = weights_orig.copy() + weights_thresh[weights_orig < weights_threshold_low] = 0 + weights_thresh[weights_orig > weights_threshold_high] = np.inf + + return weights, weights_thresh, p_inds + + +def weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + depth_bin_edges, + time_bin_edges, + # @charlie raster_kw is remove in favor of post_transform only + # raster_kw, + post_transform=np.log1p, + mincorr=0.0, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + time_horizon_s=None, + lambda_t=DEFAULT_LAMBDA_T, + eps=DEFAULT_EPS, + do_window_weights=True, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=True, + in_place=False, +): + """Transform the correlation matrix into the weights used in optimization.""" + extra = {} + + Ds = np.asarray(Ds) + Cs = np.asarray(Cs) + if Ds.ndim == 2: + Ds = Ds[None] + Cs = Cs[None] + B, T, T_ = Ds.shape + assert T == T_ + assert Ds.shape == Cs.shape + extra = {} + + Ss, mincorr = threshold_correlation_matrix( + Cs, + mincorr=mincorr, + mincorr_percentile=mincorr_percentile, + mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, + time_horizon_s=time_horizon_s, + bin_s=time_bin_edges[1] - time_bin_edges[0], + T=T, + in_place=in_place, + ) + extra["S"] = Ss + extra["mincorr"] = mincorr + + if not do_window_weights: + return Ss, extra + + # get weights + L_t = lambda_t * laplacian(T, eps=max(1e-5, eps)) + weights_orig, weights_thresh, Pind = get_weights( + Ds, + Ss, + L_t, + windows, + raster, + depth_bin_edges, + time_bin_edges, + # raster_kw, + post_transform=post_transform, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + progress_bar=progress_bar, + ) + extra["weights_orig"] = weights_orig + extra["weights_thresh"] = weights_thresh + extra["Pind"] = Pind + + # update noise model. we deliberately divide by zero and inf here. + Us = Ss if in_place else np.zeros_like(Ss) + with np.errstate(divide="ignore"): + # low mem impl of U = abs(1/(1/weights_thresh+1/weights_thresh'+1/S)) + np.reciprocal(Ss, out=Us) + invW = 1.0 / weights_thresh + Us += invW[:, :, None] + Us += invW[:, None, :] + np.reciprocal(Us, out=Us) + # handles possible -0s that cause issues elsewhere + np.abs(Us, out=Us) + # more readable equivalent: + # for b in range(B): + # invWbtt = invW[b, :, None] + invW[b, None, :] + # Us[b] = np.abs(1.0 / (invWbtt + 1.0 / Ss[b])) + extra["U"] = Us + + return Us, extra diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py new file mode 100644 index 0000000000..1b5eb75508 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -0,0 +1,296 @@ +import numpy as np + +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_3d_motion_histograms + + +class IterativeTemplateRegistration: + """ + Alignment function implemented by Kilosort2.5 and ported from pykilosort: + https://github.com/int-brain-lab/pykilosort/blob/ibl_prod/pykilosort/datashift2.py#L166 + + The main difference with respect to the original implementation are: + * scipy is used for gaussian smoothing + * windowing is implemented as gaussian tapering (instead of rectangular blocks) + * the 3d histogram is constructed in less cryptic way + * peak_locations are computed outside and so can either center fo mass or monopolar trianglation + contrary to kilosort2.5 use exclusively center of mass + + See https://www.science.org/doi/abs/10.1126/science.abf4588?cookieSet=1 + + Ported by Alessio Buccino into SpikeInterface + """ + + name = "iterative_template" + need_peak_location = True + params_doc = """ + bin_um: float, default: 10 + Spatial bin size in micrometers + hist_margin_um: float, default: 0 + Margin in um from histogram estimation. + Positive margin extrapolate out of the probe the motion. + Negative margin crop the motion on the border + bin_s: float, default: 2.0 + Bin duration in second + num_amp_bins: int, default: 20 + number ob bins in the histogram on the log amplitues dimension + num_shifts_global: int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations: int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block: int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma: float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma: float, + sigma parameter for kriging_kernel function + kriging_p: foat + p parameter for kriging_kernel function + kriging_d: float + d parameter for kriging_kernel function + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + bin_um=10.0, + hist_margin_um=0.0, + bin_s=2.0, + num_amp_bins=20, + num_shifts_global=15, + num_iterations=10, + num_shifts_block=5, + smoothing_sigma=0.5, + kriging_sigma=1, + kriging_p=2, + kriging_d=2, + ): + + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + # spatial histogram bins + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + + # get spatial windows + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depths=contact_depths, + spatial_bin_centers=spatial_bin_centers, + rigid=rigid, + win_margin_um=win_margin_um, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_shape=win_shape, + zero_threshold=None, + ) + + # make a 3D histogram + if verbose: + print("Making 3D motion histograms") + motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction=direction, + num_amp_bins=num_amp_bins, + bin_s=bin_s, + spatial_bin_edges=spatial_bin_edges, + ) + # temporal bins are bin center + temporal_bins = temporal_hist_bin_edges[:-1] + bin_s // 2.0 + + # do alignment + if verbose: + print("Estimating alignment shifts") + shift_indices, target_histogram, shift_covs_block = iterative_template_registration( + motion_histograms, + non_rigid_windows=non_rigid_windows, + num_shifts_global=num_shifts_global, + num_iterations=num_iterations, + num_shifts_block=num_shifts_block, + smoothing_sigma=smoothing_sigma, + kriging_sigma=kriging_sigma, + kriging_p=kriging_p, + kriging_d=kriging_d, + ) + + # convert to um + motion_array = -(shift_indices * bin_um) + + if extra: + extra["non_rigid_windows"] = non_rigid_windows + extra["motion_histograms"] = motion_histograms + extra["target_histogram"] = target_histogram + extra["shift_covs_block"] = shift_covs_block + extra["temporal_hist_bin_edges"] = temporal_hist_bin_edges + extra["spatial_hist_bin_edges"] = spatial_hist_bin_edges + + # replace nan by zeros + np.nan_to_num(motion_array, copy=False) + + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) + + return motion + + +def iterative_template_registration( + spikecounts_hist_images, + non_rigid_windows=None, + num_shifts_global=15, + num_iterations=10, + num_shifts_block=5, + smoothing_sigma=0.5, + kriging_sigma=1, + kriging_p=2, + kriging_d=2, +): + """ + + Parameters + ---------- + + spikecounts_hist_images : np.ndarray + Spike count histogram images (num_temporal_bins, num_spatial_bins, num_amps_bins) + non_rigid_windows : list, default: None + If num_non_rigid_windows > 1, this argument is required and it is a list of + windows to taper spatial bins in different blocks + num_shifts_global : int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations : int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block : int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma : float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma : float, default: 1 + sigma parameter for kriging_kernel function + kriging_p : float, default: 2 + p parameter for kriging_kernel function + kriging_d : float, default: 2 + d parameter for kriging_kernel function + + Returns + ------- + optimal_shift_indices + Optimal shifts for each temporal and spatial bin (num_temporal_bins, num_non_rigid_windows) + target_spikecount_hist + Target histogram used for alignment (num_spatial_bins, num_amps_bins) + """ + from scipy.ndimage import gaussian_filter, gaussian_filter1d + + # F is y bins by amp bins by batches + # ysamp are the coordinates of the y bins in um + spikecounts_hist_images = spikecounts_hist_images.swapaxes(0, 1).swapaxes(1, 2) + num_temporal_bins = spikecounts_hist_images.shape[2] + + # look up and down this many y bins to find best alignment + shift_covs = np.zeros((2 * num_shifts_global + 1, num_temporal_bins)) + shifts = np.arange(-num_shifts_global, num_shifts_global + 1) + + # mean subtraction to compute covariance + F = spikecounts_hist_images + Fg = F - np.mean(F, axis=0) + + # initialize the target "frame" for alignment with a single sample + # here we removed min(299, ...) + F0 = Fg[:, :, np.floor(num_temporal_bins / 2).astype("int") - 1] + F0 = F0[:, :, np.newaxis] + + # first we do rigid registration by integer shifts + # everything is iteratively aligned until most of the shifts become 0. + best_shifts = np.zeros((num_iterations, num_temporal_bins)) + for iteration in range(num_iterations): + for t, shift in enumerate(shifts): + # for each NEW potential shift, estimate covariance + Fs = np.roll(Fg, shift, axis=0) + shift_covs[t, :] = np.mean(Fs * F0, axis=(0, 1)) + if iteration + 1 < num_iterations: + # estimate the best shifts + imax = np.argmax(shift_covs, axis=0) + # align the data by these integer shifts + for t, shift in enumerate(shifts): + ibest = imax == t + Fg[:, :, ibest] = np.roll(Fg[:, :, ibest], shift, axis=0) + best_shifts[iteration, ibest] = shift + # new target frame based on our current best alignment + F0 = np.mean(Fg, axis=2)[:, :, np.newaxis] + target_spikecount_hist = F0[:, :, 0] + + # now we figure out how to split the probe into nblocks pieces + # if len(non_rigid_windows) = 1, then we're doing rigid registration + num_non_rigid_windows = len(non_rigid_windows) + + # for each small block, we only look up and down this many samples to find + # nonrigid shift + shifts_block = np.arange(-num_shifts_block, num_shifts_block + 1) + num_shifts = len(shifts_block) + shift_covs_block = np.zeros((2 * num_shifts_block + 1, num_temporal_bins, num_non_rigid_windows)) + + # this part determines the up/down covariance for each block without + # shifting anything + for window_index in range(num_non_rigid_windows): + win = non_rigid_windows[window_index] + window_slice = np.flatnonzero(win > 1e-5) + window_slice = slice(window_slice[0], window_slice[-1]) + tiled_window = win[window_slice, np.newaxis, np.newaxis] + Ftaper = Fg[window_slice] * np.tile(tiled_window, (1,) + Fg.shape[1:]) + for t, shift in enumerate(shifts_block): + Fs = np.roll(Ftaper, shift, axis=0) + F0taper = F0[window_slice] * np.tile(tiled_window, (1,) + F0.shape[1:]) + shift_covs_block[t, :, window_index] = np.mean(Fs * F0taper, axis=(0, 1)) + + # gaussian smoothing: + # here the original my_conv2_cpu is substituted with scipy gaussian_filters + shift_covs_block_smooth = shift_covs_block.copy() + shifts_block_up = np.linspace(-num_shifts_block, num_shifts_block, (2 * num_shifts_block * 10) + 1) + # 1. 2d smoothing over time and blocks dimensions for each shift + for shift_index in range(num_shifts): + shift_covs_block_smooth[shift_index, :, :] = gaussian_filter( + shift_covs_block_smooth[shift_index, :, :], smoothing_sigma + ) # some additional smoothing for robustness, across all dimensions + # 2. 1d smoothing over shift dimension for each spatial block + for window_index in range(num_non_rigid_windows): + shift_covs_block_smooth[:, :, window_index] = gaussian_filter1d( + shift_covs_block_smooth[:, :, window_index], smoothing_sigma, axis=0 + ) # some additional smoothing for robustness, across all dimensions + upsample_kernel = kriging_kernel( + shifts_block[:, np.newaxis], shifts_block_up[:, np.newaxis], sigma=kriging_sigma, p=kriging_p, d=kriging_d + ) + + optimal_shift_indices = np.zeros((num_temporal_bins, num_non_rigid_windows)) + for window_index in range(num_non_rigid_windows): + # using the upsampling kernel K, get the upsampled cross-correlation + # curves + upsampled_cov = upsample_kernel.T @ shift_covs_block_smooth[:, :, window_index] + + # find the max index of these curves + imax = np.argmax(upsampled_cov, axis=0) + + # add the value of the shift to the last row of the matrix of shifts + # (as if it was the last iteration of the main rigid loop ) + best_shifts[num_iterations - 1, :] = shifts_block_up[imax] + + # the sum of all the shifts equals the final shifts for this block + optimal_shift_indices[:, window_index] = np.sum(best_shifts, axis=0) + + return optimal_shift_indices, target_spikecount_hist, shift_covs_block + + +def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): + from scipy.spatial.distance import cdist + + dist_xy = cdist(source_location, target_location, metric="euclidean") + K = np.exp(-((dist_xy / sigma) ** p) / d) + return K diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py new file mode 100644 index 0000000000..6fe36a6193 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py @@ -0,0 +1,72 @@ +import numpy as np + +# TODO this need a full rewrite with motion object + + +def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None): + """ + Simple machinery to remove spurious fast bump in the motion vector. + Also can apply a smoothing. + + + Arguments + --------- + motion: numpy array 2d + Motion estimate in um. + temporal_bins: numpy.array 1d + temporal bins (bin center) + bin_duration_s: float + bin duration in second + speed_threshold: float (units um/s) + Maximum speed treshold between 2 bins allowed. + Expressed in um/s + sigma_smooth_s: None or float + Optional smooting gaussian kernel. + + Returns + ------- + corr : tensor + + + """ + motion_clean = motion.copy() + + # STEP 1 : + # * detect long plateau or small peak corssing the speed thresh + # * mask the period and interpolate + for i in range(motion.shape[1]): + one_motion = motion_clean[:, i] + speed = np.diff(one_motion, axis=0) / bin_duration_s + (inds,) = np.nonzero(np.abs(speed) > speed_threshold) + inds += 1 + if inds.size % 2 == 1: + # more compicated case: number of of inds is odd must remove first or last + # take the smallest duration sum + inds0 = inds[:-1] + inds1 = inds[1:] + d0 = np.sum(inds0[1::2] - inds0[::2]) + d1 = np.sum(inds1[1::2] - inds1[::2]) + if d0 < d1: + inds = inds0 + mask = np.ones(motion_clean.shape[0], dtype="bool") + for i in range(inds.size // 2): + mask[inds[i * 2] : inds[i * 2 + 1]] = False + import scipy.interpolate + + f = scipy.interpolate.interp1d(temporal_bins[mask], one_motion[mask]) + one_motion[~mask] = f(temporal_bins[~mask]) + + # Step 2 : gaussian smooth + if sigma_smooth_s is not None: + half_size = motion_clean.shape[0] // 2 + if motion_clean.shape[0] % 2 == 0: + # take care of the shift + bins = (np.arange(motion_clean.shape[0]) - half_size + 1) * bin_duration_s + else: + bins = (np.arange(motion_clean.shape[0]) - half_size) * bin_duration_s + smooth_kernel = np.exp(-(bins**2) / (2 * sigma_smooth_s**2)) + smooth_kernel /= np.sum(smooth_kernel) + smooth_kernel = smooth_kernel[:, None] + motion_clean = scipy.signal.fftconvolve(motion_clean, smooth_kernel, mode="same", axes=0) + + return motion_clean diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py new file mode 100644 index 0000000000..2d8564fc54 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import warnings +import numpy as np + + +from spikeinterface.sortingcomponents.tools import make_multi_method_doc + + +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges +from .decentralized import DecentralizedRegistration +from .iterative_template import IterativeTemplateRegistration +from .dredge import DredgeLfpRegistration, DredgeApRegistration + + +# estimate_motion > infer_motion +def estimate_motion( + recording, + peaks=None, + peak_locations=None, + direction="y", + rigid=False, + win_shape="gaussian", + win_step_um=50.0, # @alessio charlie is proposing here instead 400 + win_scale_um=150.0, # @alessio charlie is proposing here instead 400 + win_margin_um=None, + method="decentralized", + extra_outputs=False, + progress_bar=False, + verbose=False, + margin_um=None, + **method_kwargs, +): + """ + + + Estimate motion with several possible methods. + + Most of methods except dredge_lfp needs peaks and after their localization. + + Note that the way you detect peak locations (center of mass/monopolar_triangulation/grid_convolution) + have an impact on the result. + + Parameters + ---------- + recording: BaseRecording + The recording extractor + peaks: numpy array + Peak vector (complex dtype). + Needed for decentralized and iterative_template methods. + peak_locations: numpy array + Complex dtype with "x", "y", "z" fields + Needed for decentralized and iterative_template methods. + direction: "x" | "y" | "z", default: "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + + {method_doc} + + **non-rigid section** + + rigid : bool, default: False + Compute rigid (one motion for the entire probe) or non rigid motion + Rigid computation is equivalent to non-rigid with only one window with rectangular shape. + win_shape : "gaussian" | "rect" | "triangle", default: "gaussian" + The shape of the windows for non rigid. + When rigid this is force to "rect" + Nonrigid window-related arguments + The depth domain will be broken up into windows with shape controlled by win_shape, + spaced by win_step_um at a margin of win_margin_um from the boundary, and with + width controlled by win_scale_um. + When win_margin_um is None the margin is automatically set to -win_scale_um/2. + See get_spatial_windows. + win_step_um : float, default: 50 + See win_shape + win_scale_um : float, default: 150 + See win_shape + win_margin_um : None | float, default: None + See win_shape + extra_outputs: bool, default: False + If True then return an extra dict that contains variables + to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) + progress_bar: bool, default: False + Display progress bar or not + verbose: bool, default: False + If True, output is verbose + + + Returns + ------- + motion: Motion object + The motion object. + extra: dict + Optional output if `extra_outputs=True` + This dict contain histogram, pairwise_displacement usefull for ploting. + """ + + if margin_um is not None: + warnings.warn("estimate_motion() margin_um has been removed used hist_margin_um or win_margin_um") + + # TODO handle multi segment one day : Charlie this is for you + assert recording.get_num_segments() == 1, "At the moment estimate_motion handle only unique segment" + + method_class = estimate_motion_methods[method] + + if method_class.need_peak_location: + if peaks is None or peak_locations is None: + raise ValueError(f"estimate_motion: the method {method} need peaks and peak_locations") + + if extra_outputs: + extra = {} + else: + extra = None + + # run method + motion = method_class.run( + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ) + + if extra_outputs: + return motion, extra + else: + return motion + + +_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration] +estimate_motion_methods = {m.name: m for m in _methods_list} +method_doc = make_multi_method_doc(_methods_list) +estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py similarity index 99% rename from src/spikeinterface/sortingcomponents/motion_interpolation.py rename to src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 32bb7634e9..11ce11e1aa 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -27,6 +27,9 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording): corrected_peak_locations: np.array Motion-corrected peak locations """ + if recording is None: + raise ValueError("correct_motion_on_peaks need recording to be not None") + corrected_peak_locations = peak_locations.copy() for segment_index in range(motion.num_segments): diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py new file mode 100644 index 0000000000..a48e10b3e1 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -0,0 +1,577 @@ +import warnings +import json +from pathlib import Path + +import numpy as np +import spikeinterface +from spikeinterface.core.core_tools import check_json + + +class Motion: + """ + Motion of the tissue relative the probe. + + Parameters + ---------- + displacement : numpy array 2d or list of + Motion estimate in um. + List is the number of segment. + For each semgent : + * shape (temporal bins, spatial bins) + * motion.shape[0] = temporal_bins.shape[0] + * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) + temporal_bins_s : numpy.array 1d or list of + temporal bins (bin center) + spatial_bins_um : numpy.array 1d + Windows center. + spatial_bins_um.shape[0] == displacement.shape[1] + If rigid then spatial_bins_um.shape[0] == 1 + direction : str, default: 'y' + Direction of the motion. + interpolation_method : str + How to determine the displacement between bin centers? See the docs + for scipy.interpolate.RegularGridInterpolator for options. + """ + + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): + if isinstance(displacement, np.ndarray): + self.displacement = [displacement] + assert isinstance(temporal_bins_s, np.ndarray) + self.temporal_bins_s = [temporal_bins_s] + else: + assert isinstance(displacement, (list, tuple)) + self.displacement = displacement + self.temporal_bins_s = temporal_bins_s + + assert isinstance(spatial_bins_um, np.ndarray) + self.spatial_bins_um = spatial_bins_um + + self.num_segments = len(self.displacement) + self.interpolators = None + self.interpolation_method = interpolation_method + + self.direction = direction + self.dim = ["x", "y", "z"].index(direction) + self.check_properties() + + def check_properties(self): + assert all(d.ndim == 2 for d in self.displacement) + assert all(t.ndim == 1 for t in self.temporal_bins_s) + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + + def __repr__(self): + nbins = self.spatial_bins_um.shape[0] + if nbins == 1: + rigid_txt = "rigid" + else: + rigid_txt = f"non-rigid - {nbins} spatial bins" + + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" + return txt + + def make_interpolators(self): + from scipy.interpolate import RegularGridInterpolator + + self.interpolators = [ + RegularGridInterpolator( + (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method + ) + for j in range(self.num_segments) + ] + self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] + self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) + + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): + """Evaluate the motion estimate at times and positions + + Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement + at the given times and locations. + + Parameters + ---------- + times_s: np.array + locations_um: np.array + Either this is a one-dimensional array (a vector of positions along self.dimension), or + else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. + segment_index: int, default: None + The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. + grid : bool, default: False + If grid=False, the default, then times_s and locations_um should have the same one-dimensional + shape, and the returned displacement[i] is the displacement at time times_s[i] and location + locations_um[i]. + If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. + Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. + + Returns + ------- + displacement : np.array + A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) + if grid=True. + """ + if self.interpolators is None: + self.make_interpolators() + + if segment_index is None: + if self.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + + times_s = np.asarray(times_s) + locations_um = np.asarray(locations_um) + + if locations_um.ndim == 1: + locations_um = locations_um + elif locations_um.ndim == 2: + locations_um = locations_um[:, self.dim] + else: + assert False + + times_s = times_s.clip(*self.temporal_bounds[segment_index]) + locations_um = locations_um.clip(*self.spatial_bounds) + + if grid: + # construct a grid over which to evaluate the displacement + locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij") + out_shape = times_s.shape + locations_um = locations_um.ravel() + times_s = times_s.ravel() + else: + # usual case: input is a point cloud + assert locations_um.shape == times_s.shape + assert times_s.ndim == 1 + out_shape = times_s.shape + + points = np.column_stack((times_s, locations_um)) + displacement = self.interpolators[segment_index](points) + # reshape to grid domain shape if necessary + displacement = displacement.reshape(out_shape) + + return displacement + + def to_dict(self): + return dict( + displacement=self.displacement, + temporal_bins_s=self.temporal_bins_s, + spatial_bins_um=self.spatial_bins_um, + interpolation_method=self.interpolation_method, + direction=self.direction, + ) + + def save(self, folder): + folder = Path(folder) + folder.mkdir(exist_ok=False, parents=True) + + info_file = folder / f"spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + object="Motion", + num_segments=self.num_segments, + direction=self.direction, + interpolation_method=self.interpolation_method, + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + + np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) + + for segment_index in range(self.num_segments): + np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) + np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) + + @classmethod + def load(cls, folder): + folder = Path(folder) + + info_file = folder / f"spikeinterface_info.json" + err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object." + if not info_file.exists(): + raise IOError(err_msg) + + with open(info_file, "r") as f: + info = json.load(f) + if "object" not in info or info["object"] != "Motion": + raise IOError(err_msg) + + direction = info["direction"] + interpolation_method = info["interpolation_method"] + spatial_bins_um = np.load(folder / "spatial_bins_um.npy") + displacement = [] + temporal_bins_s = [] + for segment_index in range(info["num_segments"]): + displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) + temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) + + return cls( + displacement, + temporal_bins_s, + spatial_bins_um, + direction=direction, + interpolation_method=interpolation_method, + ) + + def __eq__(self, other): + for segment_index in range(self.num_segments): + if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): + return False + if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): + return False + + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): + return False + + return True + + def copy(self): + return Motion( + [d.copy() for d in self.displacement], + [t.copy() for t in self.temporal_bins_s], + self.spatial_bins_um.copy(), + direction=self.direction, + interpolation_method=self.interpolation_method, + ) + + +def get_spatial_windows( + contact_depths, + spatial_bin_centers, + rigid=False, + win_shape="gaussian", + win_step_um=50.0, + win_scale_um=150.0, + win_margin_um=None, + zero_threshold=None, +): + """ + Generate spatial windows (taper) for non-rigid motion. + For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe. + The windowing can be gaussian or rectangular. + Windows are centered between the min/max of contact_depths. + We can ensure window to not be to close from border with win_margin_um. + + + Parameters + ---------- + contact_depths : np.ndarray + Position of electrodes of the corection direction shape=(num_channels, ) + spatial_bin_centers : np.array + The pre-computed spatial bin centers + rigid : bool, default False + If True, returns a single rectangular window + win_shape : str, default "gaussian" + Shape of the window + "gaussian" | "rect" | "triangle" + win_step_um : float + The steps at which windows are defined + win_scale_um : float, default 150. + Sigma of gaussian window if win_shape is gaussian + Width of the rectangle if win_shape is rect + win_margin_um : None | float, default None + The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows. + When None, then the margin is set to -win_scale_um./2 + zero_threshold: None | float + Lower value for thresholding to set zeros. + + Returns + ------- + windows : 2D arrays + The scaling for each window. Each element has num_spatial_bins values + shape: (num_window, spatial_bins) + window_centers: 1D np.array + The center of each window + + Notes + ----- + Note that kilosort2.5 uses overlaping rectangular windows. + Here by default we use gaussian window. + + """ + n = spatial_bin_centers.size + + if rigid: + # win_shape = 'rect' is forced + windows, window_centers = get_rigid_windows(spatial_bin_centers) + else: + if win_scale_um <= win_step_um / 5.0: + warnings.warn( + f"get_spatial_windows(): spatial windows are probably not overlapping because {win_scale_um=} and {win_step_um=}" + ) + + if win_margin_um is None: + # this ensure that first/last windows do not overflow outside the probe + win_margin_um = -win_scale_um / 2.0 + + min_ = np.min(contact_depths) - win_margin_um + max_ = np.max(contact_depths) + win_margin_um + num_windows = int((max_ - min_) // win_step_um) + + if num_windows < 1: + raise Exception( + f"get_spatial_windows(): {win_step_um=}/{win_scale_um=}/{win_margin_um=} are too large for the " + f"probe size (depth range={np.ptp(contact_depths)}). You can try to reduce them or use rigid motion." + ) + border = ((max_ - min_) % win_step_um) / 2 + window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border + windows = [] + + for win_center in window_centers: + if win_shape == "gaussian": + win = np.exp(-((spatial_bin_centers - win_center) ** 2) / (2 * win_scale_um**2)) + elif win_shape == "rect": + win = np.abs(spatial_bin_centers - win_center) < (win_scale_um / 2.0) + win = win.astype("float64") + elif win_shape == "triangle": + center_dist = np.abs(spatial_bin_centers - win_center) + in_window = center_dist <= (win_scale_um / 2.0) + win = -center_dist + win[~in_window] = 0 + win[in_window] -= win[in_window].min() + win[in_window] /= win[in_window].max() + windows.append(win) + + windows = np.array(windows) + + if zero_threshold is not None: + windows[windows < zero_threshold] = 0 + windows /= windows.sum(axis=1, keepdims=True) + + return windows, window_centers + + +def get_rigid_windows(spatial_bin_centers): + """Generate a single rectangular window for rigid motion.""" + windows = np.ones((1, spatial_bin_centers.size), dtype="float64") + window_centers = np.array([(spatial_bin_centers[0] + spatial_bin_centers[-1]) / 2.0]) + return windows, window_centers + + +def get_window_domains(windows): + """Array of windows -> list of slices where window > 0.""" + slices = [] + for w in windows: + in_window = np.flatnonzero(w) + slices.append(slice(in_window[0], in_window[-1] + 1)) + return slices + + +def scipy_conv1d(input, weights, padding="valid"): + """SciPy translation of torch F.conv1d""" + from scipy.signal import correlate + + n, c_in, length = input.shape + c_out, in_by_groups, kernel_size = weights.shape + assert in_by_groups == c_in == 1 + + if padding == "same": + mode = "same" + length_out = length + elif padding == "valid": + mode = "valid" + length_out = length - 2 * (kernel_size // 2) + elif isinstance(padding, int): + mode = "valid" + input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)]) + length_out = length - (kernel_size - 1) + 2 * padding + else: + raise ValueError(f"Unknown 'padding' value of {padding}, 'padding' must be 'same', 'valid' or an integer") + + output = np.zeros((n, c_out, length_out), dtype=input.dtype) + for m in range(n): + for c in range(c_out): + output[m, c] = correlate(input[m, 0], weights[c, 0], mode=mode) + + return output + + +def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um): + # contact along one axis + probe = recording.get_probe() + dim = ["x", "y", "z"].index(direction) + contact_depths = probe.contact_positions[:, dim] + + min_ = np.min(contact_depths) - hist_margin_um + max_ = np.max(contact_depths) + hist_margin_um + spatial_bins = np.arange(min_, max_ + bin_um, bin_um) + + return spatial_bins + + +def make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=False, + avg_in_bin=True, + direction="y", + bin_s=1.0, + bin_um=2.0, + hist_margin_um=50, + spatial_bin_edges=None, + depth_smooth_um=None, + time_smooth_s=None, +): + """ + Generate 2d motion histogram in depth and time. + + Parameters + ---------- + recording : BaseRecording + The input recording + peaks : np.array + The peaks array + peak_locations : np.array + Array with peak locations + weight_with_amplitude : bool, default: False + If True, motion histogram is weighted by amplitudes + avg_in_bin : bool, default True + If true, average the amplitudes in each bin. + This is done only if weight_with_amplitude=True. + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_s : float, default: 1.0 + The temporal bin duration in s + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + hist_margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. + Ignored if spatial_bin_edges is given. + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges + depth_smooth_um: None or float + Optional gaussian smoother on histogram on depth axis. + This is given as the sigma of the gaussian in micrometers. + time_smooth_s: None or float + Optional gaussian smoother on histogram on time axis. + This is given as the sigma of the gaussian in seconds. + + Returns + ------- + motion_histogram + 2d np.array with motion histogram (num_temporal_bins, num_spatial_bins) + temporal_bin_edges + 1d array with temporal bin edges + spatial_bin_edges + 1d array with spatial bin edges + """ + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) + if spatial_bin_edges is None: + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + else: + bin_um = spatial_bin_edges[1] - spatial_bin_edges[0] + + arr = np.zeros((peaks.size, 2), dtype="float64") + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) + arr[:, 1] = peak_locations[direction] + + if weight_with_amplitude: + weights = np.abs(peaks["amplitude"]) + else: + weights = None + + motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) + + # average amplitude in each bin + if weight_with_amplitude and avg_in_bin: + bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) + bin_counts[bin_counts == 0] = 1 + motion_histogram = motion_histogram / bin_counts + + from scipy.ndimage import gaussian_filter1d + + if depth_smooth_um is not None: + motion_histogram = gaussian_filter1d(motion_histogram, depth_smooth_um / bin_um, axis=1, mode="constant") + + if time_smooth_s is not None: + motion_histogram = gaussian_filter1d(motion_histogram, time_smooth_s / bin_s, axis=0, mode="constant") + + return motion_histogram, temporal_bin_edges, spatial_bin_edges + + +def make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction="y", + bin_s=1.0, + bin_um=2.0, + hist_margin_um=50, + num_amp_bins=20, + log_transform=True, + spatial_bin_edges=None, +): + """ + Generate 3d motion histograms in depth, amplitude, and time. + This is used by the "iterative_template_registration" (Kilosort2.5) method. + + + Parameters + ---------- + recording : BaseRecording + The input recording + peaks : np.array + The peaks array + peak_locations : np.array + Array with peak locations + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_s : float, default: 1.0 + The temporal bin duration in s. + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + hist_margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. + Ignored if spatial_bin_edges is given. + log_transform : bool, default: True + If True, histograms are log-transformed + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges + + Returns + ------- + motion_histograms + 3d np.array with motion histogram (num_temporal_bins, num_spatial_bins, num_amp_bins) + temporal_bin_edges + 1d array with temporal bin edges + spatial_bin_edges + 1d array with spatial bin edges + """ + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) + if spatial_bin_edges is None: + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + + # pre-compute abs amplitude and ranges for scaling + amplitude_bin_edges = np.linspace(0, 1, num_amp_bins + 1) + abs_peaks = np.abs(peaks["amplitude"]) + max_peak_amp = np.max(abs_peaks) + min_peak_amp = np.min(abs_peaks) + # log amplitudes and scale between 0-1 + abs_peaks_log_norm = (np.log10(abs_peaks) - np.log10(min_peak_amp)) / ( + np.log10(max_peak_amp) - np.log10(min_peak_amp) + ) + + arr = np.zeros((peaks.size, 3), dtype="float64") + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) + arr[:, 1] = peak_locations[direction] + arr[:, 2] = abs_peaks_log_norm + + motion_histograms, edges = np.histogramdd( + arr, + bins=( + temporal_bin_edges, + spatial_bin_edges, + amplitude_bin_edges, + ), + ) + + if log_transform: + motion_histograms = np.log2(1 + motion_histograms) + + return motion_histograms, temporal_bin_edges, spatial_bin_edges diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py new file mode 100644 index 0000000000..8133c1fa6b --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py @@ -0,0 +1,9 @@ +import pytest + + +def test_dredge_online_lfp(): + pass + + +if __name__ == "__main__": + pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py similarity index 90% rename from src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py index af62ba52ec..3c83a56b9d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py @@ -3,7 +3,7 @@ import numpy as np import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -18,12 +18,11 @@ plt.show() -@pytest.fixture(scope="module") -def setup_module(tmp_path_factory): - recording, sorting = make_dataset() - cache_folder = tmp_path_factory.mktemp("cache_folder") +def setup_dataset_and_peaks(cache_folder): + print(cache_folder, type(cache_folder)) cache_folder.mkdir(parents=True, exist_ok=True) + recording, sorting = make_dataset() # detect and localize extract_dense_waveforms = ExtractDenseWaveforms(recording, ms_before=0.1, ms_after=0.3, return_output=False) pipeline_nodes = [ @@ -49,9 +48,16 @@ def setup_module(tmp_path_factory): return recording, sorting, cache_folder -def test_estimate_motion(setup_module): +@pytest.fixture(scope="module", name="dataset") +def dataset_fixture(create_cache_folder): + cache_folder = create_cache_folder / "motion_estimation" + return setup_dataset_and_peaks(cache_folder) + + +def test_estimate_motion(dataset): # recording, sorting = make_dataset() - recording, sorting, cache_folder = setup_module + recording, sorting, cache_folder = dataset + peaks = np.load(cache_folder / "dataset_peaks.npy") peak_locations = np.load(cache_folder / "dataset_peak_locations.npy") @@ -146,14 +152,14 @@ def test_estimate_motion(setup_module): kwargs = dict( direction="y", - bin_duration_s=1.0, + bin_s=1.0, bin_um=10.0, margin_um=5, - output_extra_check=True, + extra_outputs=True, ) kwargs.update(cases_kwargs) - motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) + motion, extra = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion if cases_kwargs["rigid"]: @@ -215,5 +221,9 @@ def test_estimate_motion(setup_module): if __name__ == "__main__": - setup_module() - test_estimate_motion() + import tempfile + + with tempfile.TemporaryDirectory() as tmpdirname: + cache_folder = Path(tmpdirname) + args = setup_dataset_and_peaks(cache_folder) + test_estimate_motion(args) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py similarity index 97% rename from src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index cb26560272..e022f0cc6c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,13 +4,13 @@ import pytest import spikeinterface.core as sc from spikeinterface import download_dataset -from spikeinterface.sortingcomponents.motion_interpolation import ( +from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py similarity index 97% rename from src/spikeinterface/sortingcomponents/tests/test_motion_utils.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py index 0b67be39c0..73c469c955 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion.motion_utils import Motion from spikeinterface.generation import make_one_displacement_vector if hasattr(pytest, "global_test_folder"): diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py deleted file mode 100644 index 3134d68681..0000000000 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ /dev/null @@ -1,1547 +0,0 @@ -from __future__ import annotations - -from tqdm.auto import tqdm, trange -import numpy as np - - -from .motion_utils import Motion -from .tools import make_multi_method_doc - -try: - import torch - import torch.nn.functional as F - - HAVE_TORCH = True -except ImportError: - HAVE_TORCH = False - - -def estimate_motion( - recording, - peaks, - peak_locations, - direction="y", - bin_duration_s=10.0, - bin_um=10.0, - margin_um=0.0, - rigid=False, - win_shape="gaussian", - win_step_um=50.0, - win_sigma_um=150.0, - post_clean=False, - speed_threshold=30, - sigma_smooth_s=None, - method="decentralized", - output_extra_check=False, - progress_bar=False, - upsample_to_histogram_bin=False, - verbose=False, - **method_kwargs, -): - """ - Estimate motion for given peaks and after their localization. - - Note that the way you detect peak locations (center of mass/monopolar triangulation) have an impact on the result. - - Parameters - ---------- - recording: BaseRecording - The recording extractor - peaks: numpy array - Peak vector (complex dtype) - peak_locations: numpy array - Complex dtype with "x", "y", "z" fields - - {method_doc} - - **histogram section** - - direction: "x" | "y" | "z", default: "y" - Dimension on which the motion is estimated. "y" is depth along the probe. - bin_duration_s: float, default: 10 - Bin duration in second - bin_um: float, default: 10 - Spatial bin size in micrometers - margin_um: float, default: 0 - Margin in um to exclude from histogram estimation and - non-rigid smoothing functions to avoid edge effects. - Positive margin extrapolate out of the probe the motion. - Negative margin crop the motion on the border - - **non-rigid section** - - rigid : bool, default: False - Compute rigid (one motion for the entire probe) or non rigid motion - Rigid computation is equivalent to non-rigid with only one window with rectangular shape. - win_shape: "gaussian" | "rect" | "triangle", default: "gaussian" - The shape of the windows for non rigid. - When rigid this is force to "rect" - win_step_um: float, default: 50 - Step deteween window - win_sigma_um: float, default: 150 - Sigma of the gaussian window - - **motion cleaning section** - - post_clean: bool, default: False - Apply some post cleaning to motion matrix or not - speed_threshold: float default: 30. - Detect to fast motion bump and remove then with interpolation - sigma_smooth_s: None or float - Optional smooting gaussian kernel when not None - - output_extra_check: bool, default: False - If True then return an extra dict that contains variables - to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - upsample_to_histogram_bin: bool or None, default: False - If True then upsample the returned motion array to the number of depth bins specified by bin_um. - When None: - * for non rigid case: then automatically True - * for rigid (non_rigid_kwargs=None): automatically False - This feature is in fact a bad idea and the interpolation should be done outside using better methods - progress_bar: bool, default: False - Display progress bar or not - verbose: bool, default: False - If True, output is verbose - - - Returns - ------- - motion: Motion object - The motion object. - extra_check: dict - Optional output if `output_extra_check=True` - This dict contain histogram, pairwise_displacement usefull for ploting. - """ - # TODO handle multi segment one day - assert recording.get_num_segments() == 1 - - if output_extra_check: - extra_check = {} - else: - extra_check = None - - # contact positions - probe = recording.get_probe() - dim = ["x", "y", "z"].index(direction) - contact_pos = probe.contact_positions[:, dim] - - # spatial bins - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - # get windows - non_rigid_windows, non_rigid_window_centers = get_windows( - rigid, bin_um, contact_pos, spatial_bin_edges, margin_um, win_step_um, win_sigma_um, win_shape - ) - - if output_extra_check: - extra_check["non_rigid_windows"] = non_rigid_windows - - # run method - method_class = estimate_motion_methods[method] - motion_array, temporal_bins = method_class.run( - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - **method_kwargs, - ) - - # replace nan by zeros - np.nan_to_num(motion_array, copy=False) - - if post_clean: - motion_array = clean_motion_vector( - motion_array, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s - ) - - if upsample_to_histogram_bin is None: - upsample_to_histogram_bin = not rigid - if upsample_to_histogram_bin: - extra_check["motion_array"] = motion_array - extra_check["non_rigid_window_centers"] = non_rigid_window_centers - non_rigid_windows = np.array(non_rigid_windows) - non_rigid_windows /= non_rigid_windows.sum(axis=0, keepdims=True) - non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 - motion_array = motion_array @ non_rigid_windows - - # TODO handle multi segment - motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) - - if output_extra_check: - return motion, extra_check - else: - return motion - - -class DecentralizedRegistration: - """ - Method developed by the Paninski's group from Columbia university: - Charlie Windolf, Julien Boussard, Erdem Varol, Hyun Dong Lee - - This method is also known as DREDGe, but this implemenation does not use LFP signals. - - Original reference: - DECENTRALIZED MOTION INFERENCE AND REGISTRATION OF NEUROPIXEL DATA - https://ieeexplore.ieee.org/document/9414145 - https://proceedings.neurips.cc/paper/2021/hash/b950ea26ca12daae142bd74dba4427c8-Abstract.html - - This code was improved during Spike Sorting NY Hackathon 2022 by Erdem Varol and Charlie Windolf. - An additional major improvement can be found in this paper: - https://www.biorxiv.org/content/biorxiv/early/2022/12/05/2022.12.04.519043.full.pdf - - - Here are some various implementations by the original team: - https://github.com/int-brain-lab/spikes_localization_registration/blob/main/registration_pipeline/image_based_motion_estimate.py#L211 - https://github.com/cwindolf/spike-psvae/tree/main/spike_psvae - https://github.com/evarol/DREDge - """ - - name = "decentralized" - params_doc = """ - histogram_depth_smooth_um: None or float - Optional gaussian smoother on histogram on depth axis. - This is given as the sigma of the gaussian in micrometers. - histogram_time_smooth_s: None or float - Optional gaussian smoother on histogram on time axis. - This is given as the sigma of the gaussian in seconds. - pairwise_displacement_method: "conv" or "phase_cross_correlation" - How to estimate the displacement in the pairwise matrix. - max_displacement_um: float - Maximum possible displacement in micrometers. - weight_scale: "linear" or "exp" - For parwaise displacement, how to to rescale the associated weight matrix. - error_sigma: float, default: 0.2 - In case weight_scale="exp" this controls the sigma of the exponential. - conv_engine: "numpy" or "torch" or None, default: None - In case of pairwise_displacement_method="conv", what library to use to compute - the underlying correlation - torch_device=None - In case of conv_engine="torch", you can control which device (cpu or gpu) - batch_size: int - Size of batch for the convolution. Increasing this will speed things up dramatically - on GPUs and sometimes on CPU as well. - corr_threshold: float - Minimum correlation between pair of time bins in order for these to be - considered when optimizing a global displacment vector to align with - the pairwise displacements. - time_horizon_s: None or float - When not None the parwise discplament matrix is computed in a small time horizon. - In short only pair of bins close in time. - So the pariwaise matrix is super sparse and have values only the diagonal. - convergence_method: "lsmr" | "lsqr_robust" | "gradient_descent", default: "lsqr_robust" - Which method to use to compute the global displacement vector from the pairwise matrix. - robust_regression_sigma: float - Use for convergence_method="lsqr_robust" for iterative selection of the regression. - temporal_prior : bool, default: True - Ensures continuity across time, unless there is evidence in the recording for jumps. - spatial_prior : bool, default: False - Ensures continuity across space. Not usually necessary except in recordings with - glitches across space. - force_spatial_median_continuity: bool, default: False - When spatial_prior=False we can optionally apply a median continuity across spatial windows. - reference_displacement : string, one of: "mean", "median", "time", "mode_search" - Strategy for picking what is considered displacement=0. - - "mean" : the mean displacement is subtracted - - "median" : the median displacement is subtracted - - "time" : the displacement at a given time (in seconds) is subtracted - - "mode_search" : an attempt is made to guess the mode. needs work. - lsqr_robust_n_iter: int - Number of iteration for convergence_method="lsqr_robust". - """ - - @classmethod - def run( - cls, - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - histogram_depth_smooth_um=None, - histogram_time_smooth_s=None, - pairwise_displacement_method="conv", - max_displacement_um=100.0, - weight_scale="linear", - error_sigma=0.2, - conv_engine=None, - torch_device=None, - batch_size=1, - corr_threshold=0.0, - time_horizon_s=None, - convergence_method="lsqr_robust", - soft_weights=False, - normalized_xcorr=True, - centered_xcorr=True, - temporal_prior=True, - spatial_prior=False, - force_spatial_median_continuity=False, - reference_displacement="median", - reference_displacement_time_s=0, - robust_regression_sigma=2, - lsqr_robust_n_iter=20, - weight_with_amplitude=False, - ): - # use torch if installed - if conv_engine is None: - conv_engine = "torch" if HAVE_TORCH else "numpy" - - # make 2D histogram raster - if verbose: - print("Computing motion histogram") - - motion_histogram, temporal_hist_bin_edges, spatial_hist_bin_edges = make_2d_motion_histogram( - recording, - peaks, - peak_locations, - direction=direction, - bin_duration_s=bin_duration_s, - spatial_bin_edges=spatial_bin_edges, - weight_with_amplitude=weight_with_amplitude, - ) - import scipy.signal - - if histogram_depth_smooth_um is not None: - bins = np.arange(motion_histogram.shape[1]) * bin_um - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_depth_smooth_um**2)) - smooth_kernel /= np.sum(smooth_kernel) - - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[None, :], mode="same", axes=1) - - if histogram_time_smooth_s is not None: - bins = np.arange(motion_histogram.shape[0]) * bin_duration_s - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_time_smooth_s**2)) - smooth_kernel /= np.sum(smooth_kernel) - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[:, None], mode="same", axes=0) - - if extra_check is not None: - extra_check["motion_histogram"] = motion_histogram - extra_check["pairwise_displacement_list"] = [] - extra_check["temporal_hist_bin_edges"] = temporal_hist_bin_edges - extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges - - # temporal bins are bin center - temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) - - motion = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) - windows_iter = non_rigid_windows - if progress_bar: - windows_iter = tqdm(windows_iter, desc="windows") - if spatial_prior: - all_pairwise_displacements = np.empty( - (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 - ) - all_pairwise_displacement_weights = np.empty( - (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 - ) - for i, win in enumerate(windows_iter): - window_slice = np.flatnonzero(win > 1e-5) - window_slice = slice(window_slice[0], window_slice[-1]) - if verbose: - print(f"Computing pairwise displacement: {i + 1} / {len(non_rigid_windows)}") - - pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement( - motion_histogram[:, window_slice], - bin_um, - window=win[window_slice], - method=pairwise_displacement_method, - weight_scale=weight_scale, - error_sigma=error_sigma, - conv_engine=conv_engine, - torch_device=torch_device, - batch_size=batch_size, - max_displacement_um=max_displacement_um, - normalized_xcorr=normalized_xcorr, - centered_xcorr=centered_xcorr, - corr_threshold=corr_threshold, - time_horizon_s=time_horizon_s, - bin_duration_s=bin_duration_s, - progress_bar=False, - ) - - if spatial_prior: - all_pairwise_displacements[i] = pairwise_displacement - all_pairwise_displacement_weights[i] = pairwise_displacement_weight - - if extra_check is not None: - extra_check["pairwise_displacement_list"].append(pairwise_displacement) - - if verbose: - print(f"Computing global displacement: {i + 1} / {len(non_rigid_windows)}") - - # TODO: if spatial_prior, do this after the loop - if not spatial_prior: - motion[:, i] = compute_global_displacement( - pairwise_displacement, - pairwise_displacement_weight=pairwise_displacement_weight, - convergence_method=convergence_method, - robust_regression_sigma=robust_regression_sigma, - lsqr_robust_n_iter=lsqr_robust_n_iter, - temporal_prior=temporal_prior, - spatial_prior=spatial_prior, - soft_weights=soft_weights, - progress_bar=False, - ) - - if spatial_prior: - motion = compute_global_displacement( - all_pairwise_displacements, - pairwise_displacement_weight=all_pairwise_displacement_weights, - convergence_method=convergence_method, - robust_regression_sigma=robust_regression_sigma, - lsqr_robust_n_iter=lsqr_robust_n_iter, - temporal_prior=temporal_prior, - spatial_prior=spatial_prior, - soft_weights=soft_weights, - progress_bar=False, - ) - elif len(non_rigid_windows) > 1: - # if spatial_prior is False, we still want keep the spatial bins - # correctly offset from each other - if force_spatial_median_continuity: - for i in range(len(non_rigid_windows) - 1): - motion[:, i + 1] -= np.median(motion[:, i + 1] - motion[:, i]) - - # try to avoid constant offset - # let the user choose how to do this. here are some ideas. - # (one can also -= their own number on the result of this function.) - if reference_displacement == "mean": - motion -= motion.mean() - elif reference_displacement == "median": - motion -= np.median(motion) - elif reference_displacement == "time": - # reference the motion to 0 at a specific time, independently in each window - reference_displacement_bin = np.digitize(reference_displacement_time_s, temporal_hist_bin_edges) - 1 - motion -= motion[reference_displacement_bin, :] - elif reference_displacement == "mode_search": - # just a sketch of an idea - # things might want to change, should have a configurable bin size, - # should use a call to histogram instead of the loop, ... - step_size = 0.1 - round_mode = np.round # floor? - best_ref = np.median(motion) - max_zeros = np.sum(round_mode(motion - best_ref) == 0) - for ref in np.arange(np.floor(motion.min()), np.ceil(motion.max()), step_size): - n_zeros = np.sum(round_mode(motion - ref) == 0) - if n_zeros > max_zeros: - max_zeros = n_zeros - best_ref = ref - motion -= best_ref - - return motion, temporal_bins - - -class IterativeTemplateRegistration: - """ - Alignment function implemented by Kilosort2.5 and ported from pykilosort: - https://github.com/int-brain-lab/pykilosort/blob/ibl_prod/pykilosort/datashift2.py#L166 - - The main difference with respect to the original implementation are: - * scipy is used for gaussian smoothing - * windowing is implemented as gaussian tapering (instead of rectangular blocks) - * the 3d histogram is constructed in less cryptic way - * peak_locations are computed outside and so can either center fo mass or monopolar trianglation - contrary to kilosort2.5 use exclusively center of mass - - See https://www.science.org/doi/abs/10.1126/science.abf4588?cookieSet=1 - - Ported by Alessio Buccino into SpikeInterface - """ - - name = "iterative_template" - params_doc = """ - num_amp_bins: int, default: 20 - number ob bins in the histogram on the log amplitues dimension - num_shifts_global: int, default: 15 - Number of spatial bin shifts to consider for global alignment - num_iterations: int, default: 10 - Number of iterations for global alignment procedure - num_shifts_block: int, default: 5 - Number of spatial bin shifts to consider for non-rigid alignment - smoothing_sigma: float, default: 0.5 - Sigma of gaussian for covariance matrices smoothing - kriging_sigma: float, - sigma parameter for kriging_kernel function - kriging_p: foat - p parameter for kriging_kernel function - kriging_d: float - d parameter for kriging_kernel function - """ - - @classmethod - def run( - cls, - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - num_amp_bins=20, - num_shifts_global=15, - num_iterations=10, - num_shifts_block=5, - smoothing_sigma=0.5, - kriging_sigma=1, - kriging_p=2, - kriging_d=2, - ): - # make a 3D histogram - motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms( - recording, - peaks, - peak_locations, - direction=direction, - num_amp_bins=num_amp_bins, - bin_duration_s=bin_duration_s, - spatial_bin_edges=spatial_bin_edges, - ) - # temporal bins are bin center - temporal_bins = temporal_hist_bin_edges[:-1] + bin_duration_s // 2.0 - - # do alignment - shift_indices, target_histogram, shift_covs_block = iterative_template_registration( - motion_histograms, - non_rigid_windows=non_rigid_windows, - num_shifts_global=num_shifts_global, - num_iterations=num_iterations, - num_shifts_block=num_shifts_block, - smoothing_sigma=smoothing_sigma, - kriging_sigma=kriging_sigma, - kriging_p=kriging_p, - kriging_d=kriging_d, - ) - - # convert to um - motion = -(shift_indices * bin_um) - - if extra_check: - extra_check["motion_histograms"] = motion_histograms - extra_check["target_histogram"] = target_histogram - extra_check["shift_covs_block"] = shift_covs_block - extra_check["temporal_hist_bin_edges"] = temporal_hist_bin_edges - extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges - - return motion, temporal_bins - - -_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration] -estimate_motion_methods = {m.name: m for m in _methods_list} -method_doc = make_multi_method_doc(_methods_list) -estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) - - -def get_spatial_bin_edges(recording, direction, margin_um, bin_um): - # contact along one axis - probe = recording.get_probe() - dim = ["x", "y", "z"].index(direction) - contact_pos = probe.contact_positions[:, dim] - - min_ = np.min(contact_pos) - margin_um - max_ = np.max(contact_pos) + margin_um - spatial_bins = np.arange(min_, max_ + bin_um, bin_um) - - return spatial_bins - - -def get_windows(rigid, bin_um, contact_pos, spatial_bin_edges, margin_um, win_step_um, win_sigma_um, win_shape): - """ - Generate spatial windows (taper) for non-rigid motion. - For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe. - The windowing can be gaussian or rectangular. - - Parameters - ---------- - rigid : bool - If True, returns a single rectangular window - bin_um : float - Spatial bin size in um - contact_pos : np.ndarray - Position of electrodes (num_channels, 2) - spatial_bin_edges : np.array - The pre-computed spatial bin edges - margin_um : float - The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows.= - win_step_um : float - The steps at which windows are defined - win_sigma_um : float - Sigma of gaussian window (if win_shape is gaussian) - win_shape : float - "gaussian" | "rect" - - Returns - ------- - non_rigid_windows : list of 1D arrays - The scaling for each window. Each element has num_spatial_bins values - non_rigid_window_centers: 1D np.array - The center of each window - - Notes - ----- - Note that kilosort2.5 uses overlaping rectangular windows. - Here by default we use gaussian window. - - """ - bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0 - n = bin_centers.size - - if rigid: - # win_shape = 'rect' is forced - non_rigid_windows = [np.ones(n, dtype="float64")] - middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0 - non_rigid_window_centers = np.array([middle]) - else: - assert win_sigma_um >= win_step_um, f"win_sigma_um too low {win_sigma_um} compared to win_step_um {win_step_um}" - - min_ = np.min(contact_pos) - margin_um - max_ = np.max(contact_pos) + margin_um - num_non_rigid_windows = int((max_ - min_) // win_step_um) - border = ((max_ - min_) % win_step_um) / 2 - non_rigid_window_centers = np.arange(num_non_rigid_windows + 1) * win_step_um + min_ + border - non_rigid_windows = [] - - for win_center in non_rigid_window_centers: - if win_shape == "gaussian": - win = np.exp(-((bin_centers - win_center) ** 2) / (2 * win_sigma_um**2)) - elif win_shape == "rect": - win = np.abs(bin_centers - win_center) < (win_sigma_um / 2.0) - win = win.astype("float64") - elif win_shape == "triangle": - center_dist = np.abs(bin_centers - win_center) - in_window = center_dist <= (win_sigma_um / 2.0) - win = -center_dist - win[~in_window] = 0 - win[in_window] -= win[in_window].min() - win[in_window] /= win[in_window].max() - - non_rigid_windows.append(win) - - return non_rigid_windows, non_rigid_window_centers - - -def make_2d_motion_histogram( - recording, - peaks, - peak_locations, - weight_with_amplitude=False, - direction="y", - bin_duration_s=1.0, - bin_um=2.0, - margin_um=50, - spatial_bin_edges=None, -): - """ - Generate 2d motion histogram in depth and time. - - Parameters - ---------- - recording : BaseRecording - The input recording - peaks : np.array - The peaks array - peak_locations : np.array - Array with peak locations - weight_with_amplitude : bool, default: False - If True, motion histogram is weighted by amplitudes - direction : "x" | "y" | "z", default: "y" - The depth direction - bin_duration_s : float, default: 1.0 - The temporal bin duration in s - bin_um : float, default: 2.0 - The spatial bin size in um. Ignored if spatial_bin_edges is given. - margin_um : float, default: 50 - The margin to add to the minimum and maximum positions before spatial binning. - Ignored if spatial_bin_edges is given. - spatial_bin_edges : np.array, default: None - The pre-computed spatial bin edges - - Returns - ------- - motion_histogram - 2d np.array with motion histogram (num_temporal_bins, num_spatial_bins) - temporal_bin_edges - 1d array with temporal bin edges - spatial_bin_edges - 1d array with spatial bin edges - """ - n_samples = recording.get_num_samples() - mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) - temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) - if spatial_bin_edges is None: - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - arr = np.zeros((peaks.size, 2), dtype="float64") - arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) - arr[:, 1] = peak_locations[direction] - - if weight_with_amplitude: - weights = np.abs(peaks["amplitude"]) - else: - weights = None - - motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) - - # average amplitude in each bin - if weight_with_amplitude: - bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) - bin_counts[bin_counts == 0] = 1 - motion_histogram = motion_histogram / bin_counts - - return motion_histogram, temporal_bin_edges, spatial_bin_edges - - -def make_3d_motion_histograms( - recording, - peaks, - peak_locations, - direction="y", - bin_duration_s=1.0, - bin_um=2.0, - margin_um=50, - num_amp_bins=20, - log_transform=True, - spatial_bin_edges=None, -): - """ - Generate 3d motion histograms in depth, amplitude, and time. - This is used by the "iterative_template_registration" (Kilosort2.5) method. - - - Parameters - ---------- - recording : BaseRecording - The input recording - peaks : np.array - The peaks array - peak_locations : np.array - Array with peak locations - direction : "x" | "y" | "z", default: "y" - The depth direction - bin_duration_s : float, default: 1.0 - The temporal bin duration in s. - bin_um : float, default: 2.0 - The spatial bin size in um. Ignored if spatial_bin_edges is given. - margin_um : float, default: 50 - The margin to add to the minimum and maximum positions before spatial binning. - Ignored if spatial_bin_edges is given. - log_transform : bool, default: True - If True, histograms are log-transformed - spatial_bin_edges : np.array, default: None - The pre-computed spatial bin edges - - Returns - ------- - motion_histograms - 3d np.array with motion histogram (num_temporal_bins, num_spatial_bins, num_amp_bins) - temporal_bin_edges - 1d array with temporal bin edges - spatial_bin_edges - 1d array with spatial bin edges - """ - n_samples = recording.get_num_samples() - mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) - temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) - if spatial_bin_edges is None: - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - # pre-compute abs amplitude and ranges for scaling - amplitude_bin_edges = np.linspace(0, 1, num_amp_bins + 1) - abs_peaks = np.abs(peaks["amplitude"]) - max_peak_amp = np.max(abs_peaks) - min_peak_amp = np.min(abs_peaks) - # log amplitudes and scale between 0-1 - abs_peaks_log_norm = (np.log10(abs_peaks) - np.log10(min_peak_amp)) / ( - np.log10(max_peak_amp) - np.log10(min_peak_amp) - ) - - arr = np.zeros((peaks.size, 3), dtype="float64") - arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) - arr[:, 1] = peak_locations[direction] - arr[:, 2] = abs_peaks_log_norm - - motion_histograms, edges = np.histogramdd( - arr, - bins=( - temporal_bin_edges, - spatial_bin_edges, - amplitude_bin_edges, - ), - ) - - if log_transform: - motion_histograms = np.log2(1 + motion_histograms) - - return motion_histograms, temporal_bin_edges, spatial_bin_edges - - -def compute_pairwise_displacement( - motion_hist, - bin_um, - method="conv", - weight_scale="linear", - error_sigma=0.2, - conv_engine="numpy", - torch_device=None, - batch_size=1, - max_displacement_um=1500, - corr_threshold=0, - time_horizon_s=None, - normalized_xcorr=True, - centered_xcorr=True, - bin_duration_s=None, - progress_bar=False, - window=None, -): - """ - Compute pairwise displacement - """ - from scipy import linalg - - assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" - size = motion_hist.shape[0] - pairwise_displacement = np.zeros((size, size), dtype="float32") - - if time_horizon_s is not None: - band_width = int(np.ceil(time_horizon_s / bin_duration_s)) - if band_width >= size: - time_horizon_s = None - - if conv_engine == "torch": - if torch_device is None: - torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - if method == "conv": - if max_displacement_um is None: - n = motion_hist.shape[1] // 2 - else: - n = min( - motion_hist.shape[1] // 2, - int(np.ceil(max_displacement_um // bin_um)), - ) - possible_displacement = np.arange(-n, n + 1) * bin_um - - xrange = trange if progress_bar else range - - motion_hist_engine = motion_hist - window_engine = window - if conv_engine == "torch": - motion_hist_engine = torch.as_tensor(motion_hist, dtype=torch.float32, device=torch_device) - window_engine = torch.as_tensor(window, dtype=torch.float32, device=torch_device) - - pairwise_displacement = np.empty((size, size), dtype=np.float32) - correlation = np.empty((size, size), dtype=motion_hist.dtype) - - for i in xrange(0, size, batch_size): - corr = normxcorr1d( - motion_hist_engine, - motion_hist_engine[i : i + batch_size], - weights=window_engine, - padding=possible_displacement.size // 2, - conv_engine=conv_engine, - normalized=normalized_xcorr, - centered=centered_xcorr, - ) - if conv_engine == "torch": - max_corr, best_disp_inds = torch.max(corr, dim=2) - best_disp = possible_displacement[best_disp_inds.cpu()] - pairwise_displacement[i : i + batch_size] = best_disp - correlation[i : i + batch_size] = max_corr.cpu() - elif conv_engine == "numpy": - best_disp_inds = np.argmax(corr, axis=2) - max_corr = np.take_along_axis(corr, best_disp_inds[..., None], 2).squeeze() - best_disp = possible_displacement[best_disp_inds] - pairwise_displacement[i : i + batch_size] = best_disp - correlation[i : i + batch_size] = max_corr - - if corr_threshold is not None and corr_threshold > 0: - which = correlation > corr_threshold - correlation *= which - - elif method == "phase_cross_correlation": - # this 'phase_cross_correlation' is an old idea from Julien/Charlie/Erden that is kept for testing - # but this is not very releveant - try: - import skimage.registration - except ImportError: - raise ImportError("To use the 'phase_cross_correlation' method install scikit-image") - - errors = np.zeros((size, size), dtype="float32") - loop = range(size) - if progress_bar: - loop = tqdm(loop) - for i in loop: - for j in range(size): - shift, error, diffphase = skimage.registration.phase_cross_correlation( - motion_hist[i, :], motion_hist[j, :] - ) - pairwise_displacement[i, j] = shift * bin_um - errors[i, j] = error - correlation = 1 - errors - - else: - raise ValueError( - f"method {method} does not exist for compute_pairwise_displacement. Current possible methods are" - f" 'conv' or 'phase_cross_correlation'" - ) - - if weight_scale == "linear": - # between 0 and 1 - pairwise_displacement_weight = correlation - elif weight_scale == "exp": - pairwise_displacement_weight = np.exp((correlation - 1) / error_sigma) - - # handle the time horizon by multiplying the weights by a - # matrix with the time horizon on its diagonal bands. - if method == "conv" and time_horizon_s is not None and time_horizon_s > 0: - horizon_matrix = linalg.toeplitz( - np.r_[np.ones(band_width, dtype=bool), np.zeros(size - band_width, dtype=bool)] - ) - pairwise_displacement_weight *= horizon_matrix - - return pairwise_displacement, pairwise_displacement_weight - - -_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust") - - -def compute_global_displacement( - pairwise_displacement, - pairwise_displacement_weight=None, - sparse_mask=None, - temporal_prior=True, - spatial_prior=True, - soft_weights=False, - convergence_method="lsmr", - robust_regression_sigma=2, - lsqr_robust_n_iter=20, - progress_bar=False, -): - """ - Compute global displacement - - Arguments - --------- - pairwise_displacement : time x time array - pairwise_displacement_weight : time x time array - sparse_mask : time x time array - convergence_method : str - One of "gradient" - - """ - import scipy - from scipy.optimize import minimize - from scipy.sparse import csr_matrix - from scipy.sparse.linalg import lsqr - from scipy.stats import zscore - - if convergence_method == "gradient_descent": - size = pairwise_displacement.shape[0] - - D = pairwise_displacement - if pairwise_displacement_weight is not None or sparse_mask is not None: - # weighted problem - if pairwise_displacement_weight is None: - pairwise_displacement_weight = np.ones_like(D) - if sparse_mask is None: - sparse_mask = np.ones_like(D) - W = pairwise_displacement_weight * sparse_mask - - I, J = np.nonzero(W > 0) - Wij = W[I, J] - Dij = D[I, J] - W = csr_matrix((Wij, (I, J)), shape=W.shape) - WD = csr_matrix((Wij * Dij, (I, J)), shape=W.shape) - fixed_terms = (W @ WD).diagonal() - (WD @ W).diagonal() - diag_WW = (W @ W).diagonal() - Wsq = W.power(2) - - def obj(p): - return 0.5 * np.square(Wij * (Dij - (p[I] - p[J]))).sum() - - def jac(p): - return fixed_terms - 2 * (Wsq @ p) + 2 * p * diag_WW - - else: - # unweighted problem, it's faster when we have no weights - fixed_terms = -D.sum(axis=1) + D.sum(axis=0) - - def obj(p): - v = np.square((D - (p[:, None] - p[None, :]))).sum() - return 0.5 * v - - def jac(p): - return fixed_terms + 2 * (size * p - p.sum()) - - res = minimize(fun=obj, jac=jac, x0=D.mean(axis=1), method="L-BFGS-B") - if not res.success: - print("Global displacement gradient descent had an error") - displacement = res.x - - elif convergence_method == "lsqr_robust": - - if sparse_mask is not None: - I, J = np.nonzero(sparse_mask > 0) - elif pairwise_displacement_weight is not None: - I, J = pairwise_displacement_weight.nonzero() - else: - I, J = np.nonzero(np.ones_like(pairwise_displacement, dtype=bool)) - - nnz_ones = np.ones(I.shape[0], dtype=pairwise_displacement.dtype) - - if pairwise_displacement_weight is not None: - if isinstance(pairwise_displacement_weight, scipy.sparse.csr_matrix): - W = np.array(pairwise_displacement_weight[I, J]).T - else: - W = pairwise_displacement_weight[I, J][:, None] - else: - W = nnz_ones[:, None] - if isinstance(pairwise_displacement, scipy.sparse.csr_matrix): - V = np.array(pairwise_displacement[I, J])[0] - else: - V = pairwise_displacement[I, J] - M = csr_matrix((nnz_ones, (range(I.shape[0]), I)), shape=(I.shape[0], pairwise_displacement.shape[0])) - N = csr_matrix((nnz_ones, (range(I.shape[0]), J)), shape=(I.shape[0], pairwise_displacement.shape[0])) - A = M - N - idx = np.ones(A.shape[0], dtype=bool) - - # TODO: this is already soft_weights - xrange = trange if progress_bar else range - for i in xrange(lsqr_robust_n_iter): - p = lsqr(A[idx].multiply(W[idx]), V[idx] * W[idx][:, 0])[0] - idx = np.nonzero(np.abs(zscore(A @ p - V)) <= robust_regression_sigma) - displacement = p - - elif convergence_method == "lsmr": - import gc - from scipy import sparse - - D = pairwise_displacement - - # weighted problem - if pairwise_displacement_weight is None: - pairwise_displacement_weight = np.ones_like(D) - if sparse_mask is None: - sparse_mask = np.ones_like(D) - W = pairwise_displacement_weight * sparse_mask - if isinstance(W, scipy.sparse.csr_matrix): - W = W.astype(np.float32).toarray() - D = D.astype(np.float32).toarray() - - assert D.shape == W.shape - - # first dimension is the windows dim, which could be empty in rigid case - # we expand dims so that below we can consider only the nonrigid case - if D.ndim == 2: - W = W[None] - D = D[None] - assert D.ndim == W.ndim == 3 - B, T, T_ = D.shape - assert T == T_ - - # sparsify the problem - # we will make a list of temporal problems and then - # stack over the windows axis to finish. - # each matrix in coefficients will be (sparse_dim, T) - coefficients = [] - # each vector in targets will be (T,) - targets = [] - # we want to solve for a vector of shape BT, which we will reshape - # into a (B, T) matrix. - # after the loop below, we will stack a coefts matrix (sparse_dim, B, T) - # and a target vector of shape (B, T), both to be vectorized on last two axes, - # so that the target p is indexed by i = bT + t (block/window major). - - # calculate coefficients matrices and target vector - # this list stores boolean masks corresponding to whether or not each - # term comes from the prior or the likelihood. we can trim the likelihood terms, - # but not the prior terms, in the trimmed least squares (robust iters) iterations below. - cannot_trim = [] - for Wb, Db in zip(W, D): - # indices of active temporal pairs in this window - I, J = np.nonzero(Wb > 0) - n_sampled = I.size - - # construct Kroneckers and sparse objective in this window - pair_weights = np.ones(n_sampled) - if soft_weights: - pair_weights = Wb[I, J] - Mb = sparse.csr_matrix((pair_weights, (range(n_sampled), I)), shape=(n_sampled, T)) - Nb = sparse.csr_matrix((pair_weights, (range(n_sampled), J)), shape=(n_sampled, T)) - block_sparse_kron = Mb - Nb - block_disp_pairs = pair_weights * Db[I, J] - cannot_trim_block = np.ones_like(block_disp_pairs, dtype=bool) - - # add the temporal smoothness prior in this window - if temporal_prior: - temporal_diff_operator = sparse.diags( - ( - np.full(T - 1, -1, dtype=block_sparse_kron.dtype), - np.full(T - 1, 1, dtype=block_sparse_kron.dtype), - ), - offsets=(0, 1), - shape=(T - 1, T), - ) - block_sparse_kron = sparse.vstack( - (block_sparse_kron, temporal_diff_operator), - format="csr", - ) - block_disp_pairs = np.concatenate( - (block_disp_pairs, np.zeros(T - 1)), - ) - cannot_trim_block = np.concatenate( - (cannot_trim_block, np.zeros(T - 1, dtype=bool)), - ) - - coefficients.append(block_sparse_kron) - targets.append(block_disp_pairs) - cannot_trim.append(cannot_trim_block) - coefficients = sparse.block_diag(coefficients) - targets = np.concatenate(targets, axis=0) - cannot_trim = np.concatenate(cannot_trim, axis=0) - - # spatial smoothness prior: penalize difference of each block's - # displacement with the next. - # only if B > 1, and not in the last window. - # this is a (BT, BT) sparse matrix D such that: - # entry at (i, j) is: - # { 1 if i = j, i.e., i = j = bT + t for b = 0,...,B-2 - # { -1 if i = bT + t and j = (b+1)T + t for b = 0,...,B-2 - # { 0 otherwise. - # put more simply, the first (B-1)T diagonal entries are 1, - # and entries (i, j) such that i = j - T are -1. - if B > 1 and spatial_prior: - spatial_diff_operator = sparse.diags( - ( - np.ones((B - 1) * T, dtype=block_sparse_kron.dtype), - np.full((B - 1) * T, -1, dtype=block_sparse_kron.dtype), - ), - offsets=(0, T), - shape=((B - 1) * T, B * T), - ) - coefficients = sparse.vstack((coefficients, spatial_diff_operator)) - targets = np.concatenate((targets, np.zeros((B - 1) * T, dtype=targets.dtype))) - cannot_trim = np.concatenate((cannot_trim, np.zeros((B - 1) * T, dtype=bool))) - coefficients = coefficients.tocsr() - - # initialize at the column mean of pairwise displacements (in each window) - p0 = D.mean(axis=2).reshape(B * T) - - # use LSMR to solve the whole problem || targets - coefficients @ motion ||^2 - iters = range(max(1, lsqr_robust_n_iter)) - if progress_bar and lsqr_robust_n_iter > 1: - iters = tqdm(iters, desc="robust lsqr") - for it in iters: - # trim active set -- start with no trimming - idx = slice(None) - if it: - idx = np.flatnonzero( - cannot_trim | (np.abs(zscore(coefficients @ displacement - targets)) <= robust_regression_sigma) - ) - - # solve trimmed ols problem - displacement, *_ = sparse.linalg.lsmr(coefficients[idx], targets[idx], x0=p0) - - # warm start next iteration - p0 = displacement - # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) - # TODO: check if this gets fixed in scipy - gc.collect() - - displacement = displacement.reshape(B, T).T - else: - raise ValueError( - f"Method {convergence_method} doesn't exist for compute_global_displacement" - f" possible values for 'convergence_method' are {_possible_convergence_method}" - ) - - return np.squeeze(displacement) - - -def iterative_template_registration( - spikecounts_hist_images, - non_rigid_windows=None, - num_shifts_global=15, - num_iterations=10, - num_shifts_block=5, - smoothing_sigma=0.5, - kriging_sigma=1, - kriging_p=2, - kriging_d=2, -): - """ - - Parameters - ---------- - - spikecounts_hist_images : np.ndarray - Spike count histogram images (num_temporal_bins, num_spatial_bins, num_amps_bins) - non_rigid_windows : list, default: None - If num_non_rigid_windows > 1, this argument is required and it is a list of - windows to taper spatial bins in different blocks - num_shifts_global : int, default: 15 - Number of spatial bin shifts to consider for global alignment - num_iterations : int, default: 10 - Number of iterations for global alignment procedure - num_shifts_block : int, default: 5 - Number of spatial bin shifts to consider for non-rigid alignment - smoothing_sigma : float, default: 0.5 - Sigma of gaussian for covariance matrices smoothing - kriging_sigma : float, default: 1 - sigma parameter for kriging_kernel function - kriging_p : float, default: 2 - p parameter for kriging_kernel function - kriging_d : float, default: 2 - d parameter for kriging_kernel function - - Returns - ------- - optimal_shift_indices - Optimal shifts for each temporal and spatial bin (num_temporal_bins, num_non_rigid_windows) - target_spikecount_hist - Target histogram used for alignment (num_spatial_bins, num_amps_bins) - """ - from scipy.ndimage import gaussian_filter, gaussian_filter1d - - # F is y bins by amp bins by batches - # ysamp are the coordinates of the y bins in um - spikecounts_hist_images = spikecounts_hist_images.swapaxes(0, 1).swapaxes(1, 2) - num_temporal_bins = spikecounts_hist_images.shape[2] - - # look up and down this many y bins to find best alignment - shift_covs = np.zeros((2 * num_shifts_global + 1, num_temporal_bins)) - shifts = np.arange(-num_shifts_global, num_shifts_global + 1) - - # mean subtraction to compute covariance - F = spikecounts_hist_images - Fg = F - np.mean(F, axis=0) - - # initialize the target "frame" for alignment with a single sample - # here we removed min(299, ...) - F0 = Fg[:, :, np.floor(num_temporal_bins / 2).astype("int") - 1] - F0 = F0[:, :, np.newaxis] - - # first we do rigid registration by integer shifts - # everything is iteratively aligned until most of the shifts become 0. - best_shifts = np.zeros((num_iterations, num_temporal_bins)) - for iteration in range(num_iterations): - for t, shift in enumerate(shifts): - # for each NEW potential shift, estimate covariance - Fs = np.roll(Fg, shift, axis=0) - shift_covs[t, :] = np.mean(Fs * F0, axis=(0, 1)) - if iteration + 1 < num_iterations: - # estimate the best shifts - imax = np.argmax(shift_covs, axis=0) - # align the data by these integer shifts - for t, shift in enumerate(shifts): - ibest = imax == t - Fg[:, :, ibest] = np.roll(Fg[:, :, ibest], shift, axis=0) - best_shifts[iteration, ibest] = shift - # new target frame based on our current best alignment - F0 = np.mean(Fg, axis=2)[:, :, np.newaxis] - target_spikecount_hist = F0[:, :, 0] - - # now we figure out how to split the probe into nblocks pieces - # if len(non_rigid_windows) = 1, then we're doing rigid registration - num_non_rigid_windows = len(non_rigid_windows) - - # for each small block, we only look up and down this many samples to find - # nonrigid shift - shifts_block = np.arange(-num_shifts_block, num_shifts_block + 1) - num_shifts = len(shifts_block) - shift_covs_block = np.zeros((2 * num_shifts_block + 1, num_temporal_bins, num_non_rigid_windows)) - - # this part determines the up/down covariance for each block without - # shifting anything - for window_index in range(num_non_rigid_windows): - win = non_rigid_windows[window_index] - window_slice = np.flatnonzero(win > 1e-5) - window_slice = slice(window_slice[0], window_slice[-1]) - tiled_window = win[window_slice, np.newaxis, np.newaxis] - Ftaper = Fg[window_slice] * np.tile(tiled_window, (1,) + Fg.shape[1:]) - for t, shift in enumerate(shifts_block): - Fs = np.roll(Ftaper, shift, axis=0) - F0taper = F0[window_slice] * np.tile(tiled_window, (1,) + F0.shape[1:]) - shift_covs_block[t, :, window_index] = np.mean(Fs * F0taper, axis=(0, 1)) - - # gaussian smoothing: - # here the original my_conv2_cpu is substituted with scipy gaussian_filters - shift_covs_block_smooth = shift_covs_block.copy() - shifts_block_up = np.linspace(-num_shifts_block, num_shifts_block, (2 * num_shifts_block * 10) + 1) - # 1. 2d smoothing over time and blocks dimensions for each shift - for shift_index in range(num_shifts): - shift_covs_block_smooth[shift_index, :, :] = gaussian_filter( - shift_covs_block_smooth[shift_index, :, :], smoothing_sigma - ) # some additional smoothing for robustness, across all dimensions - # 2. 1d smoothing over shift dimension for each spatial block - for window_index in range(num_non_rigid_windows): - shift_covs_block_smooth[:, :, window_index] = gaussian_filter1d( - shift_covs_block_smooth[:, :, window_index], smoothing_sigma, axis=0 - ) # some additional smoothing for robustness, across all dimensions - upsample_kernel = kriging_kernel( - shifts_block[:, np.newaxis], shifts_block_up[:, np.newaxis], sigma=kriging_sigma, p=kriging_p, d=kriging_d - ) - - optimal_shift_indices = np.zeros((num_temporal_bins, num_non_rigid_windows)) - for window_index in range(num_non_rigid_windows): - # using the upsampling kernel K, get the upsampled cross-correlation - # curves - upsampled_cov = upsample_kernel.T @ shift_covs_block_smooth[:, :, window_index] - - # find the max index of these curves - imax = np.argmax(upsampled_cov, axis=0) - - # add the value of the shift to the last row of the matrix of shifts - # (as if it was the last iteration of the main rigid loop ) - best_shifts[num_iterations - 1, :] = shifts_block_up[imax] - - # the sum of all the shifts equals the final shifts for this block - optimal_shift_indices[:, window_index] = np.sum(best_shifts, axis=0) - - return optimal_shift_indices, target_spikecount_hist, shift_covs_block - - -def normxcorr1d( - template, - x, - weights=None, - centered=True, - normalized=True, - padding="same", - conv_engine="torch", -): - """normxcorr1d: Normalized cross-correlation, optionally weighted - - The API is like torch's F.conv1d, except I have accidentally - changed the position of input/weights -- template acts like weights, - and x acts like input. - - Returns the cross-correlation of `template` and `x` at spatial lags - determined by `mode`. Useful for estimating the location of `template` - within `x`. - - This might not be the most efficient implementation -- ideas welcome. - It uses a direct convolutional translation of the formula - corr = (E[XY] - EX EY) / sqrt(var X * var Y) - - This also supports weights! In that case, the usual adaptation of - the above formula is made to the weighted case -- and all of the - normalizations are done per block in the same way. - - Parameters - ---------- - template : tensor, shape (num_templates, length) - The reference template signal - x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) - The signal in which to find `template` - weights : tensor, shape (length,) - Will use weighted means, variances, covariances if supplied. - centered : bool - If true, means will be subtracted (per weighted patch). - normalized : bool - If true, normalize by the variance (per weighted patch). - padding : str - How far to look? if unset, we'll use half the length - conv_engine : string, one of "torch", "numpy" - What library to use for computing cross-correlations. - If numpy, falls back to the scipy correlate function. - - Returns - ------- - corr : tensor - """ - if conv_engine == "torch": - assert HAVE_TORCH - conv1d = F.conv1d - npx = torch - elif conv_engine == "numpy": - conv1d = scipy_conv1d - npx = np - else: - raise ValueError(f"Unknown conv_engine {conv_engine}. 'conv_engine' must be 'torch' or 'numpy'") - - x = npx.atleast_2d(x) - num_templates, length = template.shape - num_inputs, length_ = template.shape - assert length == length_ - - # generalize over weighted / unweighted case - device_kw = {} if conv_engine == "numpy" else dict(device=x.device) - ones = npx.ones((1, 1, length), dtype=x.dtype, **device_kw) - no_weights = weights is None - if no_weights: - weights = ones - wt = template[:, None, :] - else: - assert weights.shape == (length,) - weights = weights[None, None] - wt = template[:, None, :] * weights - - # conv1d valid rule: - # (B,1,L),(O,1,L)->(B,O,L) - - # compute expectations - # how many points in each window? seems necessary to normalize - # for numerical stability. - N = conv1d(ones, weights, padding=padding) - if centered: - Et = conv1d(ones, wt, padding=padding) - Et /= N - Ex = conv1d(x[:, None, :], weights, padding=padding) - Ex /= N - - # compute (weighted) covariance - # important: the formula E[XY] - EX EY is well-suited here, - # because the means are naturally subtracted correctly - # patch-wise. you couldn't pre-subtract them! - cov = conv1d(x[:, None, :], wt, padding=padding) - cov /= N - if centered: - cov -= Ex * Et - - # compute variances for denominator, using var X = E[X^2] - (EX)^2 - if normalized: - var_template = conv1d(ones, wt * template[:, None, :], padding=padding) - var_template /= N - var_x = conv1d(npx.square(x)[:, None, :], weights, padding=padding) - var_x /= N - if centered: - var_template -= npx.square(Et) - var_x -= npx.square(Ex) - - # now find the final normxcorr - corr = cov # renaming for clarity - if normalized: - corr /= npx.sqrt(var_x) - corr /= npx.sqrt(var_template) - # get rid of NaNs in zero-variance areas - corr[~npx.isfinite(corr)] = 0 - - return corr - - -def scipy_conv1d(input, weights, padding="valid"): - """SciPy translation of torch F.conv1d""" - from scipy.signal import correlate - - n, c_in, length = input.shape - c_out, in_by_groups, kernel_size = weights.shape - assert in_by_groups == c_in == 1 - - if padding == "same": - mode = "same" - length_out = length - elif padding == "valid": - mode = "valid" - length_out = length - 2 * (kernel_size // 2) - elif isinstance(padding, int): - mode = "valid" - input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)]) - length_out = length - (kernel_size - 1) + 2 * padding - else: - raise ValueError(f"Unknown 'padding' value of {padding}, 'padding' must be 'same', 'valid' or an integer") - - output = np.zeros((n, c_out, length_out), dtype=input.dtype) - for m in range(n): - for c in range(c_out): - output[m, c] = correlate(input[m, 0], weights[c, 0], mode=mode) - - return output - - -def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None): - """ - Simple machinery to remove spurious fast bump in the motion vector. - Also can applyt a smoothing. - - - Arguments - --------- - motion: numpy array 2d - Motion estimate in um. - temporal_bins: numpy.array 1d - temporal bins (bin center) - bin_duration_s: float - bin duration in second - speed_threshold: float (units um/s) - Maximum speed treshold between 2 bins allowed. - Expressed in um/s - sigma_smooth_s: None or float - Optional smooting gaussian kernel. - - Returns - ------- - corr : tensor - - - """ - motion_clean = motion.copy() - - # STEP 1 : - # * detect long plateau or small peak corssing the speed thresh - # * mask the period and interpolate - for i in range(motion.shape[1]): - one_motion = motion_clean[:, i] - speed = np.diff(one_motion, axis=0) / bin_duration_s - (inds,) = np.nonzero(np.abs(speed) > speed_threshold) - inds += 1 - if inds.size % 2 == 1: - # more compicated case: number of of inds is odd must remove first or last - # take the smallest duration sum - inds0 = inds[:-1] - inds1 = inds[1:] - d0 = np.sum(inds0[1::2] - inds0[::2]) - d1 = np.sum(inds1[1::2] - inds1[::2]) - if d0 < d1: - inds = inds0 - mask = np.ones(motion_clean.shape[0], dtype="bool") - for i in range(inds.size // 2): - mask[inds[i * 2] : inds[i * 2 + 1]] = False - import scipy.interpolate - - f = scipy.interpolate.interp1d(temporal_bins[mask], one_motion[mask]) - one_motion[~mask] = f(temporal_bins[~mask]) - - # Step 2 : gaussian smooth - if sigma_smooth_s is not None: - half_size = motion_clean.shape[0] // 2 - if motion_clean.shape[0] % 2 == 0: - # take care of the shift - bins = (np.arange(motion_clean.shape[0]) - half_size + 1) * bin_duration_s - else: - bins = (np.arange(motion_clean.shape[0]) - half_size) * bin_duration_s - smooth_kernel = np.exp(-(bins**2) / (2 * sigma_smooth_s**2)) - smooth_kernel /= np.sum(smooth_kernel) - smooth_kernel = smooth_kernel[:, None] - motion_clean = scipy.signal.fftconvolve(motion_clean, smooth_kernel, mode="same", axes=0) - - return motion_clean - - -def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): - from scipy.spatial.distance import cdist - - dist_xy = cdist(source_location, target_location, metric="euclidean") - K = np.exp(-((dist_xy / sigma) ** p) / d) - return K diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py deleted file mode 100644 index a8de3f6d13..0000000000 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ /dev/null @@ -1,234 +0,0 @@ -import json -from pathlib import Path - -import numpy as np -import spikeinterface -from spikeinterface.core.core_tools import check_json - - -class Motion: - """ - Motion of the tissue relative the probe. - - Parameters - ---------- - displacement : numpy array 2d or list of - Motion estimate in um. - List is the number of segment. - For each semgent : - * shape (temporal bins, spatial bins) - * motion.shape[0] = temporal_bins.shape[0] - * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - temporal_bins_s : numpy.array 1d or list of - temporal bins (bin center) - spatial_bins_um : numpy.array 1d - Windows center. - spatial_bins_um.shape[0] == displacement.shape[1] - If rigid then spatial_bins_um.shape[0] == 1 - direction : str, default: 'y' - Direction of the motion. - interpolation_method : str - How to determine the displacement between bin centers? See the docs - for scipy.interpolate.RegularGridInterpolator for options. - """ - - def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): - if isinstance(displacement, np.ndarray): - self.displacement = [displacement] - assert isinstance(temporal_bins_s, np.ndarray) - self.temporal_bins_s = [temporal_bins_s] - else: - assert isinstance(displacement, (list, tuple)) - self.displacement = displacement - self.temporal_bins_s = temporal_bins_s - - assert isinstance(spatial_bins_um, np.ndarray) - self.spatial_bins_um = spatial_bins_um - - self.num_segments = len(self.displacement) - self.interpolators = None - self.interpolation_method = interpolation_method - - self.direction = direction - self.dim = ["x", "y", "z"].index(direction) - self.check_properties() - - def check_properties(self): - assert all(d.ndim == 2 for d in self.displacement) - assert all(t.ndim == 1 for t in self.temporal_bins_s) - assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) - - def __repr__(self): - nbins = self.spatial_bins_um.shape[0] - if nbins == 1: - rigid_txt = "rigid" - else: - rigid_txt = f"non-rigid - {nbins} spatial bins" - - interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] - txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" - return txt - - def make_interpolators(self): - from scipy.interpolate import RegularGridInterpolator - - self.interpolators = [ - RegularGridInterpolator( - (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method - ) - for j in range(self.num_segments) - ] - self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] - self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) - - def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): - """Evaluate the motion estimate at times and positions - - Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement - at the given times and locations. - - Parameters - ---------- - times_s: np.array - The time points at which to evaluate the displacement. - locations_um: np.array - Either this is a one-dimensional array (a vector of positions along self.dimension), or - else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. - segment_index: int, default: None - The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. - grid : bool, default: False - If grid=False, the default, then times_s and locations_um should have the same one-dimensional - shape, and the returned displacement[i] is the displacement at time times_s[i] and location - locations_um[i]. - If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. - Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. - - Returns - ------- - displacement : np.array - A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) - if grid=True. - """ - if self.interpolators is None: - self.make_interpolators() - - if segment_index is None: - if self.num_segments == 1: - segment_index = 0 - else: - raise ValueError("Several segment need segment_index=") - - times_s = np.asarray(times_s) - locations_um = np.asarray(locations_um) - - if locations_um.ndim == 1: - locations_um = locations_um - elif locations_um.ndim == 2: - locations_um = locations_um[:, self.dim] - else: - assert False - - times_s = times_s.clip(*self.temporal_bounds[segment_index]) - locations_um = locations_um.clip(*self.spatial_bounds) - - if grid: - # construct a grid over which to evaluate the displacement - locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij") - out_shape = times_s.shape - locations_um = locations_um.ravel() - times_s = times_s.ravel() - else: - # usual case: input is a point cloud - assert locations_um.shape == times_s.shape - assert times_s.ndim == 1 - out_shape = times_s.shape - - points = np.column_stack((times_s, locations_um)) - displacement = self.interpolators[segment_index](points) - # reshape to grid domain shape if necessary - displacement = displacement.reshape(out_shape) - - return displacement - - def to_dict(self): - return dict( - displacement=self.displacement, - temporal_bins_s=self.temporal_bins_s, - spatial_bins_um=self.spatial_bins_um, - direction=self.direction, - interpolation_method=self.interpolation_method, - ) - - def save(self, folder): - folder = Path(folder) - folder.mkdir(exist_ok=False, parents=True) - - info_file = folder / f"spikeinterface_info.json" - info = dict( - version=spikeinterface.__version__, - dev_mode=spikeinterface.DEV_MODE, - object="Motion", - num_segments=self.num_segments, - direction=self.direction, - interpolation_method=self.interpolation_method, - ) - with open(info_file, mode="w") as f: - json.dump(check_json(info), f, indent=4) - - np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) - - for segment_index in range(self.num_segments): - np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) - np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) - - @classmethod - def load(cls, folder): - folder = Path(folder) - - info_file = folder / f"spikeinterface_info.json" - err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object." - if not info_file.exists(): - raise IOError(err_msg) - - with open(info_file, "r") as f: - info = json.load(f) - if "object" not in info or info["object"] != "Motion": - raise IOError(err_msg) - - direction = info["direction"] - interpolation_method = info["interpolation_method"] - spatial_bins_um = np.load(folder / "spatial_bins_um.npy") - displacement = [] - temporal_bins_s = [] - for segment_index in range(info["num_segments"]): - displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) - temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) - - return cls( - displacement, - temporal_bins_s, - spatial_bins_um, - direction=direction, - interpolation_method=interpolation_method, - ) - - def __eq__(self, other): - for segment_index in range(self.num_segments): - if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): - return False - if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): - return False - - if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): - return False - - return True - - def copy(self): - return Motion( - [d.copy() for d in self.displacement], - [t.copy() for t in self.temporal_bins_s], - [s.copy() for s in self.spatial_bins_um], - direction=self.direction, - interpolation_method=self.interpolation_method, - ) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 0b79350a62..81cda212b2 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from spikeinterface.core import BaseRecording, SortingAnalyzer -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion class MotionWidget(BaseWidget): @@ -230,7 +230,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.colors import Normalize from .utils_matplotlib import make_mpl_figure - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks dp = to_attr(data_plot) @@ -291,12 +291,10 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info returned by correct_motion() or loaded back with load_motion_info(). + recording : RecordingExtractor + The recording extractor object segment_index : int, default: None The segment index to display. - recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times). - segment_index : int, default: 0 - The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None). depth_lim : tuple or None, default: None @@ -320,8 +318,8 @@ class MotionInfoWidget(BaseWidget): def __init__( self, motion_info: dict, + recording: BaseRecording, segment_index: int | None = None, - recording: BaseRecording | None = None, depth_lim: tuple[float, float] | None = None, motion_lim: tuple[float, float] | None = None, color_amplitude: bool = False, @@ -366,7 +364,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks dp = to_attr(data_plot)