diff --git a/debugging/_test_session_alignment.py b/debugging/_test_session_alignment.py index 972a2af5b8..f71c754222 100644 --- a/debugging/_test_session_alignment.py +++ b/debugging/_test_session_alignment.py @@ -134,8 +134,8 @@ def _prep_recording(recording, plot=False): return peaks, peak_locations -MOTION = False # True -SAVE = False +MOTION = True # True +SAVE = True PLOT = False BIN_UM = 5 @@ -224,7 +224,7 @@ def _prep_recording(recording, plot=False): "chunked_bin_size_s": "estimate", "log_scale": True, "depth_smooth_um": 10, - "histogram_type": "activity_1d", # "y_only", "2Dy_x", "2Dy_amplitude"" TOOD: better names! + "histogram_type": "activity_1d", # "y_only", "locations_2d", "activity_2d"" TOOD: better names! } compute_alignment_kwargs = { "num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs. @@ -247,9 +247,7 @@ def _prep_recording(recording, plot=False): } if MOTION: - from session_alignment import align_sessions_after_motion_correction - - corrected_recordings_list, extra_info = align_sessions_after_motion_correction( + corrected_recordings_list, extra_info = session_alignment.align_sessions_after_motion_correction( recordings_list, motion_info_list, align_sessions_kwargs={ @@ -271,7 +269,19 @@ def _prep_recording(recording, plot=False): compute_alignment_kwargs=compute_alignment_kwargs, ) +<<<<<<< HEAD + +plotting_session_alignment.SessionAlignmentWidget( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim":(-250, 0)} # TODO: option to fix this across recordings. +) +======= plotting_session_alignment.SessionAlignmentWidget( recordings_list, peaks_list, @@ -282,14 +292,22 @@ def _prep_recording(recording, plot=False): drift_raster_map_kwargs={"clim":(-250, 0)} # TODO: option to fix this across recordings. ) +>>>>>>> 978c5343c (Reformatting alignment methods and add 2D, need to tidy up.) plt.show() # TODO: estimate chunk size Hz needs to be scaled for time? is it not been done correctly? +<<<<<<< HEAD # No, even two sessions is a mess # TODO: working assumptions, maybe after rigid, make a template for nonrigid alignment # as at the moment all nonrigid to eachother is a mess if False: +======= +if False: + # No, even two sessions is a mess + # TODO: working assumptions, maybe after rigid, make a template for nonrigid alignment + # as at the moment all nonrigid to eachother is a mess +>>>>>>> 978c5343c (Reformatting alignment methods and add 2D, need to tidy up.) A = extra_info["histogram_info_list"][2]["chunked_histograms"] mean_ = alignment_utils.get_chunked_hist_mean(A) diff --git a/debugging/peak_locs_1.npy b/debugging/peak_locs_1.npy new file mode 100644 index 0000000000..fc70903e6b Binary files /dev/null and b/debugging/peak_locs_1.npy differ diff --git a/debugging/peak_locs_2.npy b/debugging/peak_locs_2.npy new file mode 100644 index 0000000000..b88eb42a4c Binary files /dev/null and b/debugging/peak_locs_2.npy differ diff --git a/debugging/peaks_1.npy b/debugging/peaks_1.npy new file mode 100644 index 0000000000..26de87d743 Binary files /dev/null and b/debugging/peaks_1.npy differ diff --git a/debugging/peaks_2.npy b/debugging/peaks_2.npy new file mode 100644 index 0000000000..4332b48b12 Binary files /dev/null and b/debugging/peaks_2.npy differ diff --git a/debugging/playing.py b/debugging/playing.py index a571fcd1a9..793cb3ef29 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -1,4 +1,3 @@ - from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks @@ -9,57 +8,112 @@ ) import matplotlib.pyplot as plt +import spikeinterface.full as si -# -------------------------------------------------------------------------------------- -# Load / generate some recordings -# -------------------------------------------------------------------------------------- +si.set_global_job_kwargs(n_jobs=10) -recordings_list, _ = generate_session_displacement_recordings( - num_units=50, - recording_durations=[50, 50, 50], - recording_shifts=((0, 0), (0, 50), (0, 75)) -) -# -------------------------------------------------------------------------------------- -# Compute the peaks / locations with your favourite method -# -------------------------------------------------------------------------------------- -# Note if you did motion correction the peaks are on the motion object. -# There is a function 'session_alignment.align_sessions_after_motion_correction() -# you can use instead of the below. - -peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( - recordings_list, - detect_kwargs={"method": "locally_exclusive"}, - localize_peaks_kwargs={"method": "grid_convolution"}, -) +if __name__ == '__main__': -# -------------------------------------------------------------------------------------- -# Do the estimation -# -------------------------------------------------------------------------------------- -# For each session, an 'activity histogram' is generated. This can be `entire_session` -# or the session can be chunked into segments and some summary generated taken over then. -# This might be useful if periods of the recording have weird kinetics or noise. -# See `session_alignment.py` for docs on these settings. - -non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() -non_rigid_window_kwargs["rigid"] = True - -corrected_recordings_list, extra_info = session_alignment.align_sessions( - recordings_list, - peaks_list, - peak_locations_list, - alignment_order="to_session_1", # "to_session_X" or "to_middle" - non_rigid_window_kwargs=non_rigid_window_kwargs, -) + # -------------------------------------------------------------------------------------- + # Load / generate some recordings + # -------------------------------------------------------------------------------------- -plotting_session_alignment.SessionAlignmentWidget( - recordings_list, - peaks_list, - peak_locations_list, - extra_info["session_histogram_list"], - **extra_info["corrected"], - spatial_bin_centers=extra_info["spatial_bin_centers"], - drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10} -) -plt.show() + + recordings_list, _ = generate_session_displacement_recordings( + num_units=20, + recording_durations=[400, 400, 400], + recording_shifts=((0, 0), (0, 200), (0, -125)), + non_rigid_gradient=0.005, + seed=52, + ) + if False: + import numpy as np + + + recordings_list = [ + si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-38-28_VR1.zarr\M25_D18_2024-11-05_12-38-28_VR1.zarr"), + si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-08-47_OF1.zarr\M25_D18_2024-11-05_12-08-47_OF1.zarr"), + ] + + recordings_list = [si.astype(rec, np.float32) for rec in recordings_list] + recordings_list = [si.bandpass_filter(rec) for rec in recordings_list] + recordings_list = [si.common_reference(rec, operator="median") for rec in recordings_list] + + # -------------------------------------------------------------------------------------- + # Compute the peaks / locations with your favourite method + # -------------------------------------------------------------------------------------- + # Note if you did motion correction the peaks are on the motion object. + # There is a function 'session_alignment.align_sessions_after_motion_correction() + # you can use instead of the below. + + peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + + if False: + np.save("peaks_1.npy", peaks_list[0]) + np.save("peaks_2.npy", peaks_list[1]) + np.save("peak_locs_1.npy", peak_locations_list[0]) + np.save("peak_locs_2.npy", peak_locations_list[1]) + + if False: + peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy")] + peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy")] + + # -------------------------------------------------------------------------------------- + # Do the estimation + # -------------------------------------------------------------------------------------- + # For each session, an 'activity histogram' is generated. This can be `entire_session` + # or the session can be chunked into segments and some summary generated taken over then. + # This might be useful if periods of the recording have weird kinetics or noise. + # See `session_alignment.py` for docs on these settings. + + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid"] = False + # non_rigid_window_kwargs["win_shape"] = "rect" + # non_rigid_window_kwargs["win_step_um"] = 25 + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["method"] = "chunked_mean" + estimate_histogram_kwargs["histogram_type"] = "activity_1d" + estimate_histogram_kwargs["bin_um"] = 5 + + corrected_recordings_list, extra_info = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + alignment_order="to_session_1", # "to_session_X" or "to_middle" + non_rigid_window_kwargs=non_rigid_window_kwargs, + estimate_histogram_kwargs=estimate_histogram_kwargs, + ) + + # TODO: nonlinear is not working well 'to middle', investigate + # TODO: also finalise the estimation of bin number of nonrigid. + + if False: + plt.plot(extra_info["histogram_info_list"][0]["chunked_histograms"].T, color="black") + + M = extra_info["session_histogram_list"][0] + S = extra_info["histogram_info_list"][0]["session_histogram_variation"] + + plt.plot(M, color="red") + plt.plot(M + S, color="green") + plt.plot(M - S, color="green") + + plt.show() + + plotting_session_alignment.SessionAlignmentWidget( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10} + ) + + plt.show() diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py index 69f196e6fa..bac8a38b0e 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -1,5 +1,7 @@ from spikeinterface import BaseRecording import numpy as np + +from spikeinterface.preprocessing import center from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram from scipy.optimize import minimize from scipy.ndimage import gaussian_filter @@ -77,7 +79,7 @@ def get_activity_histogram( activity_histogram *= scaler if log_scale: - activity_histogram = np.log10(1 + activity_histogram) + activity_histogram = np.log10(1 + activity_histogram) # TODO: make_2d_motion_histogram uses log2 return activity_histogram, temporal_bin_centers, spatial_bin_centers @@ -405,7 +407,7 @@ def compute_histogram_crosscorrelation( mean does not make sense over sessions. """ num_sessions = len(session_histogram_list) - num_bins = session_histogram_list[0].size # all hists are same length + num_bins = session_histogram_list.shape[1] # all hists are same length num_windows = non_rigid_windows.shape[0] shift_matrix = np.zeros((num_sessions, num_sessions, num_windows)) @@ -416,36 +418,75 @@ def compute_histogram_crosscorrelation( for j in range(num_sessions): # Create the (num windows, num_bins) matrix for this pair of sessions - xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_bins * 2 - 1)) + + import matplotlib.pyplot as plt + + # TODO: plot everything + + num_iter = ( + num_bins * 2 - 1 if not num_shifts_block else num_shifts_block * 2 + ) # TODO: make sure this is clearly defined, it is either side... + xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter)) # For each window, window the session histograms (`window` is binary) # and perform the cross correlations for win_idx, window in enumerate(non_rigid_windows): - windowed_histogram_i = session_histogram_list[i, :] * window - windowed_histogram_j = session_histogram_list[j, :] * window - xcorr = np.correlate( - windowed_histogram_i, windowed_histogram_j, mode="full" - ) # TODO: add weight option. + # breakpoint() + # TODO: track the 2d histogram through all steps to check everything is working okay + + # TOOD: gaussian window with crosscorr, won't it strongly bias zero shifts by maximising the signal at 0? + # TODO: add weight option. + # TODO: damn this is slow for 2D, speed up. + if session_histogram_list.ndim == 3: + windowed_histogram_i = session_histogram_list[i, :] * window[:, np.newaxis] + windowed_histogram_j = session_histogram_list[j, :] * window[:, np.newaxis] + + from scipy.signal import correlate2d + + # carefully check indices + xcorr = correlate2d( + windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis], + windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis], + ) # TOOD: check speed, probs don't remove mean because we want zeros for unmasked version + + mid_idx = windowed_histogram_j.shape[1] - 1 + xcorr = xcorr[:, mid_idx] + + else: + windowed_histogram_i = session_histogram_list[i, :] * window + + window_target = True # this makes less sense now that things could be very far apart + if window_target: + windowed_histogram_j = session_histogram_list[j, :] * window + else: + windowed_histogram_j = session_histogram_list[j, :] + + xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full") + + # plt.plot(windowed_histogram_i) + # plt.plot(windowed_histogram_j) + # plt.show() if num_shifts_block: window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block) - mask = np.zeros_like(xcorr) - mask[window_indices] = 1 - xcorr *= mask + xcorr = xcorr[window_indices] + shift_center_bin = ( + num_shifts_block # np.floor(num_shifts_block / 2) # TODO: CHECK! and move out of loop! + ) + else: + shift_center_bin = center_bin + + # plt.plot(xcorr) + # plt.show() xcorr_matrix[win_idx, :] = xcorr + # TODO: check absolute value of different bins, they are quite different (log scale, zero mean histograms) + # TODO: print out a load of quality metrics from this! + # Smooth the cross-correlations across the bins if smoothing_sigma_bin: - breakpoint() - import matplotlib.pyplot as plt - - plt.plot(xcorr_matrix[0, :]) - X = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1) - plt.plot(X[0, :]) - plt.show() - xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1) # Smooth the cross-correlations across the windows @@ -470,7 +511,9 @@ def compute_histogram_crosscorrelation( else: xcorr_peak = np.argmax(xcorr_matrix, axis=1) - shift = xcorr_peak - center_bin + # breakpoint() + + shift = xcorr_peak - shift_center_bin # center_bin shift_matrix[i, j, :] = shift return shift_matrix @@ -502,8 +545,16 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: """ abs_shift = np.abs(shift) pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) + + if array.ndim == 2: + pad_tuple = (pad_tuple, (0, 0)) + padded_hist = np.pad(array, pad_tuple, mode="constant") - cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] + + if padded_hist.ndim == 2: + cut_padded_array = padded_hist[abs_shift:, :] if shift >= 0 else padded_hist[:-abs_shift, :] # TOOD: tidy up + else: + cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] return cut_padded_array diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py index 3c69e915de..ac305c839d 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py @@ -311,6 +311,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): elif isinstance(spatial_bin_centers, np.ndarray): spatial_bin_centers = [spatial_bin_centers] * num_histograms + # TOOD: For 2D histogram, will need to subplot and just plot the individual histograms... for i in range(num_histograms): self.ax.plot(spatial_bin_centers[i], dp.session_histogram_list[i], color=colors[i], linewidth=linewidths[i]) diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index 3ee7eac8d1..a583c5d672 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -66,7 +66,7 @@ def get_compute_alignment_kwargs() -> dict: windows along the probe depth. See `get_spatial_windows`. """ return { - "num_shifts_block": 5, + "num_shifts_block": 100, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction "interpolate": False, "interp_factor": 10, "kriging_sigma": 1, @@ -109,6 +109,8 @@ def get_interpolate_motion_kwargs(): # Public Entry Level Functions ############################################################################### +# TODO: sometimes with small bins, the interpolation spreads the signal over too small a bin and flattens it on the corrected histogram + def align_sessions( recordings_list: list[BaseRecording], @@ -684,9 +686,7 @@ def _create_motion_recordings( if isinstance(recording, InterpolateMotionRecording): - print( - "Recording is already an `InterpolateMotionRecording. " "Adding shifts directly the recording object." - ) + print("Recording is already an `InterpolateMotionRecording. Adding shifts directly the recording object.") corrected_recording = _add_displacement_to_interpolate_recording(recording, motion) else: @@ -882,13 +882,20 @@ def _compute_session_alignment( return rigid_shifts, non_rigid_windows, non_rigid_window_centers # For non-rigid, first shift the histograms according to the rigid shift - shifted_histograms = np.zeros_like(session_histogram_array) - for ses_idx, orig_histogram in enumerate(session_histogram_array): - shifted_histogram = alignment_utils.shift_array_fill_zeros( - array=orig_histogram, shift=int(rigid_shifts[ses_idx, 0]) - ) - shifted_histograms[ses_idx, :] = shifted_histogram + # When there is non-rigid drift, the rigid drift can be very wrong! + # So we depart from the kilosort approach for inter-session, + # for non-rigid, it makes sense to start without rigid alignment + shifted_histograms = session_histogram_array.copy() + + if False: + shifted_histograms = np.zeros_like(session_histogram_array) + for ses_idx, orig_histogram in enumerate(session_histogram_array): + + shifted_histogram = alignment_utils.shift_array_fill_zeros( + array=orig_histogram, shift=int(rigid_shifts[ses_idx, 0]) + ) + shifted_histograms[ses_idx, :] = shifted_histogram # Then compute the nonrigid shifts nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( @@ -901,10 +908,10 @@ def _compute_session_alignment( interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers ) - shifts = rigid_shifts + interp_nonrigid_shifts + shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts non_rigid_window_centers = spatial_bin_centers else: - shifts = rigid_shifts + non_rigid_shifts + shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts return shifts, non_rigid_windows, non_rigid_window_centers @@ -937,12 +944,12 @@ def _estimate_rigid_alignment( compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) compute_alignment_kwargs["num_shifts_block"] = False - rigid_window = np.ones_like(session_histogram_array[0, :])[np.newaxis, :] + rigid_window = np.ones(session_histogram_array.shape[1])[np.newaxis, :] rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( session_histogram_array, rigid_window, - **compute_alignment_kwargs, + **compute_alignment_kwargs, # TODO: remove the copy above and pass directly. COnsider removing this function... ) optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix( alignment_order, rigid_session_offsets_matrix diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 5af0def67c..e54721a447 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -330,6 +330,8 @@ def get_spatial_windows( window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border windows = [] + print("CENTERS: ", window_centers.size) + for win_center in window_centers: if win_shape == "gaussian": win = np.exp(-((spatial_bin_centers - win_center) ** 2) / (2 * win_scale_um**2))