diff --git a/debugging/DELinter_session_displacement.py b/debugging/DELinter_session_displacement.py new file mode 100644 index 0000000000..fabad71825 --- /dev/null +++ b/debugging/DELinter_session_displacement.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import copy + +import numpy as np +import json +from pathlib import Path +import time + +from spikeinterface.core.baserecording import BaseRecording +from spikeinterface.core import get_noise_levels, fix_job_kwargs, get_random_data_chunks +from spikeinterface.core.job_tools import _shared_job_kwargs_doc +from spikeinterface.core.core_tools import SIJsonEncoder +from spikeinterface.core.job_tools import _shared_job_kwargs_doc + +# TODO: update motion docstrings around the 'select' step. + + +# TODO: +# 1) detect peaks and peak locations if not already provided. +# - could use only a subset of data, for ease now just estimate +# everything on the entire dataset +# 2) Calcualte the activity histogram across the entire session +# - will be better ways to estimate this, i.e. from the end +# of the session, from periods of stability, etc. +# taking a weighted average of histograms +# 3) Optimise for drift correction for each session across +# all histograms, minimising lost data at edges and keeping +# shift similar for all sessions. Could alternatively shift +# to the average histogram but this seems like a bad idea. +# 4) Store the motion vectors, ether adding to existing (of motion +# objects passed) otherwise. + + +def correct_inter_session_displacement( + recordings_list: list[BaseRecording], + existing_motion_info: Optional[list[Dict]] = None, + keep_channels_constant=False, + detect_kwargs={}, # TODO: make non-mutable (same for motion.py) + select_kwargs={}, + localize_peaks_kwargs={}, + job_kwargs={}, +): + from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods + from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods + from spikeinterface.sortingcomponents.motion.motion_estimation import estimate_motion + from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording + from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline + from spikeinterface.sortingcomponents.motion.motion_utils import Motion, get_spatial_windows + + # TODO: do not accept multi-segment recordings. + # TODO: check all recordings have the same probe dimensions! + # Check if exsting_motion_info is passed then the recordings have the motion vector (I guess this is stored somewhere? maybe it is on the motion object) + if existing_motion_info is not None: + if not isinstance(existing_motion_info, list) and len(recordings_list) != len(existing_motion_info): + raise ValueError( + "`estimate_motion_info` if provided, must be" + "a list of `motion_info` with each associated with" + "the corresponding recording in `recordings_list`." + ) + + # TODO: do not handle select peaks option yet as probably better to chunk + # rather than select peaks? no sure can discuss. + if existing_motion_info is None: + + peaks_list = [] + peak_locations_list = [] + + for recording in recordings_list: + # TODO: this is a direct copy from motion.detect_motion(). 
+ # Factor into own function in motion.py + gather_mode = "memory" + # node detect + method = detect_kwargs.pop("method", "locally_exclusive") + method_class = detect_peak_methods[method] + node0 = method_class(recording, **detect_kwargs) + + node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3) + + # node detect + localize + method = localize_peaks_kwargs.pop("method", "center_of_mass") + method_class = localize_peak_methods[method] + node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs) + pipeline_nodes = [node0, node1, node2] + + peaks, peak_locations = run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + job_name="detect and localize", + gather_mode=gather_mode, + gather_kwargs=None, + squeeze_output=False, + folder=None, + names=None, + ) + peaks_list.append(peaks) + peak_locations_list.append(peak_locations) + else: + peaks_list = [info["peaks"] for info in existing_motion_info] + peak_locations_list = [info["peak_locations"] for info in existing_motion_info] + + from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms + + # make motion histogram + motion_histogram_dim = "2D" # "2D" or "3D", for now only handle 2D case + + motion_histogram_list = [] + all_temporal_bin_edges = [] # TODO: fix naming + + bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth? + + # TODO: own function + for recording, peaks, peak_locations in zip( + recordings_list, + peaks_list, + peak_locations_list, # TODO: this is overwriting above variable names. Own function! + ): # TODO: do a lot of checks to make sure these bin sizes make sesnese + # Do some checks on temporal and spatial bin edges that they are all the same? + + if motion_histogram_dim == "2D": + motion_histogram = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=False, + direction="y", + bin_s=recording.get_duration(segment_index=0), # 1.0, + bin_um=bin_um, + hist_margin_um=50, + spatial_bin_edges=None, + ) + else: + assert NotImplementedError # TODO: might be old API pre-dredge + motion_histogram = make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction="y", + bin_duration_s=recording.get_duration(segment_index=0), # 1.0, + bin_um=bin_um, + margin_um=50, + num_amp_bins=20, + log_transform=True, + spatial_bin_edges=None, + ) + motion_histogram_list.append(motion_histogram[0].squeeze()) + # store bin edges + all_temporal_bin_edges.append(motion_histogram[1]) + spatial_bin_edges_um = motion_histogram[2] # should be same across all recordings + + # Do some checks on temporal and spatial bin edges that they are all the same? + # TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence) + # Let's do a very basic optimisation to find the best midpoint, just + # align everything to the first session. This isn't great because + # introduces some bias. Maybe align to all sessions and then take some + # average. Certainly cannot optimise brute force over the whole space + # which is (2P-1)^N where P is length of motion histogram and N is number of recordings. + # TODO: double-check what is done in kilosort-like / DREDGE + # put histograms into X and do X^T X then mean(U), det or eigs of covar mat + # can try iterative template. Not sure it will work so well taking the mean + # over only a few histograms that could be wildy different. 
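+    # A minimal sketch of the rigid shift estimate used below, assuming two
+    # equal-length 1D activity histograms (not part of the pipeline). Note that
+    # for equal-length inputs np.correlate(..., mode="full") always returns
+    # 2 * n - 1 lags, so zero lag sits at index n - 1 and the even/odd midpoint
+    # check below is not strictly needed.
+    def _estimate_rigid_shift_um(hist_a, hist_b, bin_um):
+        """Displacement of `hist_a` relative to `hist_b` in micrometers; the sign
+        convention should be checked against the interpolation step before use."""
+        a = hist_a / hist_a.sum()
+        b = hist_b / hist_b.sum()
+        conv = np.correlate(a, b, mode="full")  # 2 * n - 1 lags, zero lag at n - 1
+        return (np.argmax(conv) - (a.size - 1)) * bin_um
+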
+ # Displacemene + num_recordings = len(recordings_list) + + shifts = np.zeros(num_recordings) + + # TODO: not checked any of the below properly + first_hist = motion_histogram_list[0] / motion_histogram_list[0].sum() + # first_hist -= np.mean(first_hist) # TODO: pretty sure not necessary + + for i in range(1, num_recordings): + + hist = motion_histogram_list[i] / motion_histogram_list[i].sum() + # hist -= np.mean(hist) # TODO: pretty sure not necessary + conv = np.correlate(first_hist, hist, mode="full") + + if conv.size % 2 == 0: + midpoint = conv.size / 2 + else: + midpoint = (conv.size - 1) / 2 # TODO: carefully double check! + + # TODO: think will need to make this negative + shifts[i] = (midpoint - np.argmax(conv)) * bin_um # # TODO: the bin spacing is super important for resoltuion + + # half + # TODO: need to figure out interpolation to the center point, weird;y + # the below does not work + # shifts[0] = (shifts[1] / 2) + # shifts[1] = (shifts[1] / 2) * -1 + # print("SHIFTS", shifts) + # TODO: handle only the 2D case for now + # TODO: do multi-session optimisation + + # Handle drift + interpolate_motion_kwargs = {} + + # TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object! + # Will need the y-axis bins for this + all_recording_corrected = [] + all_motion_info = [] + for i, recording in enumerate(recordings_list): + + # TODO: direct copy, use 'get_window' from motion machinery + if False: + bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0 + n = bin_centers.size + non_rigid_windows = [np.ones(n, dtype="float64")] + middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0 + non_rigid_window_centers = np.array([middle]) + + dim = 1 # ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) + + _, window_centers = get_spatial_windows( + contact_depths, spatial_bin_centers, rigid=True # TODO: handle non-rigid case + ) + # win_shape=win_shape, TODO: handle defaults better + # win_step_um=win_step_um, + # win_scale_um=win_scale_um, + # win_margin_um=win_margin_um, + # zero_threshold=1e-5, + + # if shifts[i] == 0: + ## all_recording_corrected.append(recording) # TODO + # continue + temporal_bin_edges = all_temporal_bin_edges[i] + temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1]) + + motion_array = np.zeros((temporal_bins.size, window_centers.size)) # TODO: check this is the expected shape + motion_array[:, :] = shifts[i] # TODO: this is the rigid case! + + motion = Motion( + [motion_array], [temporal_bins], window_centers, direction="y" + ) # will be same for all except for shifts + all_motion_info.append(motion) # not certain on this + + if isinstance(recording, InterpolateMotionRecording): + raise NotImplementedError + recording_corrected = copy.deepcopy(recording) + # TODO: add interpolation to the existing one. + # Not if inter-session motion correction already exists, but further + # up the preprocessing chain, it will NOT be added and interpolation + # will occur twice. Throw a warning here! 
+ else: + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) + all_recording_corrected.append(recording_corrected) + + displacement_info = { + "all_motion_info": all_motion_info, + "all_motion_histograms": motion_histogram_list, # TODO: naming + "all_shifts": shifts, + } + + if keep_channels_constant: + # TODO: use set + import functools + + common_channels = functools.reduce( + np.intersect1d, [recording.channel_ids for recording in all_recording_corrected] + ) + + all_recording_corrected = [recording.channel_slice(common_channels) for recording in all_recording_corrected] + + return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object diff --git a/debugging/__init__.py b/debugging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/debugging/_test_session_alignment.py b/debugging/_test_session_alignment.py new file mode 100644 index 0000000000..f71c754222 --- /dev/null +++ b/debugging/_test_session_alignment.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +import matplotlib.pyplot as plt +import numpy as np +import pickle +from spikeinterface.preprocessing.inter_session_alignment import ( + session_alignment, + plotting_session_alignment, + alignment_utils +) +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks +import spikeinterface.full as si + + +# TODO: all of the nonrigid methods (and even rigid) could be having some strange affects on AP +# waveforms. definately needs looking into! + +# TODO: ask about best way to chunk, as ofc the peak detection takes +# recording as inputs so cannot use get traces with chunks function as planned. +# TODO: expose trimmed versions, robust xcorr + +# Note, the cross correlation is intrinsically limited because for large +# shifts the value is too reduced by the reduction in number of points. +# but, of course cannot scale by number of points due to instability at edges +# This is a major problem, e.g. see the strange results for: +""" + scalings = [np.ones(25), np.r_[np.zeros(10), np.ones(15)]] + recordings_list, _ = generate_session_displacement_recordings( + non_rigid_gradient=None, # 0.05, # 0.05, + num_units=55, + recording_durations=(100, 100, 100, 100), + recording_shifts=( + (0, 0), (0, 250), (0, -150), (0, -210), + ), + recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings}, + generate_unit_locations_kwargs={"margin_um": 0, "minimum_z": 0, "maximum_z": 0}, + seed=42, + ) +""" + +""" +TODO: in this case, it is not necessary to run peak detection across + the entire recording, would probably be sufficient to + take a few chunks, of size determined by firing frequency of + the neurons in the recording (or just take user defined size). + For now, run on the entire recording and discuss the best way to + run on chunked sections with Sam. +""" + +# with nonrigid shift. This is less of a problem when restricting to a small +# windwo for the nonrigid because even if it fails catistrophically the nonrigid +# error will only be max(non rigid shifts). But its still not good. + +# TODO: add different modes (to mean, to nth session...) 
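+# Rough illustration of the cross-correlation limitation noted above (synthetic
+# data, assumed helper, not part of the pipeline): for a large shift the
+# unnormalised peak is damped because fewer bins overlap, while dividing by the
+# overlap count becomes unstable at the extreme lags where only a few bins remain.
+def _xcorr_overlap_demo(num_bins=200, shift_bins=80, seed=0):
+    rng = np.random.default_rng(seed)
+    hist = rng.poisson(5, num_bins).astype(float)
+    shifted = np.zeros_like(hist)
+    shifted[shift_bins:] = hist[:-shift_bins]  # displace by `shift_bins` and zero-pad
+
+    xcorr = np.correlate(hist, shifted, mode="full")
+    lags = np.arange(xcorr.size) - (num_bins - 1)
+    overlap = num_bins - np.abs(lags)  # number of overlapping bins at each lag
+
+    # Compare both estimates against the true lag (negative because `shifted`
+    # is displaced toward higher indices).
+    return lags[np.argmax(xcorr)], lags[np.argmax(xcorr / overlap)], -shift_bins
+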
+# TODO: document that the output is Hz + +# TODO: major check, refactor and tidy up +# list out carefully all notes +# handle the case where the passed recordings are not motion correction recordings. + +# 3) think about and add new neurons that are introduced when shifted + +# 4) add interpolation of the histograms prior to cross correlation +# 5) add robust cross-correlation +# 6) add trimmed methods +# 7) add better way to estimate chunk length. + +# try and interpolate /smooth the xcorr. What about smoothing the activity histograms directly? +# look into te akima spline + +# TODO: think about the nonrigid alignment, it correlates +# over the entire window. is this wise? try cutting it down a bit? + +# TODO: We this interpolate, smooth for xcorr in both rigid and non-rigid case. Think aboout this / check this is ok +# maybe we only want to apply the smoothings etc for nonrigid like KS motion correction + +# TODO: try forcing all unit locations to actually +# be within the probe. Add some notes on this because it is confusing. + +# 1) write argument checks +# 2) go through with a fine tooth comb, fix all outstanding issues, tidy up, +# plot everything to make sure it is working prior to writing tests. + +# 4) to an optimisation shift and scale instead of the current xcorr method. +# 5) finalise estimation of chunk size (skip for now, try a new alignment method) +# and optimal bin size. +# 6) make some presets? should estimate a lot of parameters based on the data, especially for nonrigid. +# these are all basically based on probe geometry. + + +# Note, shifting can move a unit closer to the channel e.g. if separated +# by 20 um which can increase the signal and make shift estimation harder. + +# go through everything and plot to check before writing tests. +# there is a relationship between bin size and nonrigid bins. If bins are +# too small then nonrigid is very unstable. So either choosing a bigger bin +# size or smoothing over the histogram in relation to the number +# of nonrigid bins may make sense. + +# the results for nonrigid are very dependent on chosen parameters, +# in particular the number of nonrigid windows, gaussian scale, +# smoothing of the histgram. An optimaisation method may also +# serve to help reduce the number of parameters to choose. + +# what you really want is for the window size to adapt to how +# busy the histogram is. + +# Suprisingly, the session that is aligned TO can have a +# major affect. 
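+# Sketch of the depth-smoothing idea discussed above (assumed helper, not part of
+# the pipeline): smoothing the activity histogram with a Gaussian whose width is
+# tied to the bin size (and kept well below the non-rigid window spacing) stops
+# single bins from dominating the per-window cross-correlation.
+from scipy.ndimage import gaussian_filter1d
+
+def _smooth_session_histogram(histogram, bin_um, smooth_um):
+    sigma_bins = smooth_um / bin_um  # e.g. depth_smooth_um=10 with bin_um=5 -> sigma of 2 bins
+    return gaussian_filter1d(histogram, sigma=sigma_bins)
+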
+ +# problem with current: +# - xcorr is not the best for large shifts due to lower num overlapping samples +# - + +def _prep_recording(recording, plot=False): + """ + :param recording: + :return: + """ + peaks = detect_peaks(recording, method="locally_exclusive") + + peak_locations = localize_peaks(recording, peaks, method="grid_convolution") + + if plot: + si.plot_drift_raster_map( + peaks=peaks, + peak_locations=peak_locations, + recording=recording, + clim=(-300, 0), # fix clim for comparability across plots + ) + plt.show() + + return peaks, peak_locations + +MOTION = True # True +SAVE = True +PLOT = False +BIN_UM = 5 + + +if SAVE: + scalings = [np.ones(25), np.r_[np.zeros(10), np.ones(15)]] + recordings_list, _ = generate_session_displacement_recordings( + non_rigid_gradient=None, # 0.05, # 0.05, # 0.05, + num_units=55, + recording_durations=(50, 50), # , 100), + recording_shifts=( + (0, 0), + (0, 75), + ), + recording_amplitude_scalings=None, # {"method": "by_amplitude_and_firing_rate", "scalings": scalings}, + generate_unit_locations_kwargs={"margin_um": 0, "minimum_z": 0, "maximum_z": 0}, + generate_templates_kwargs=dict( + ms_before=1.5, + ms_after=3.0, + mode="sphere", # this is key to maintaining consistent unit positions with shift + unit_params=dict( + alpha=(75, 125.0), # firing rate + spatial_decay=(10, 45), + ), + ), + seed=42, + ) + + if not MOTION: + peaks_list = [] + peak_locations_list = [] + + for recording in recordings_list: + peaks, peak_locations = _prep_recording( + recording, + plot=PLOT, + ) + peaks_list.append(peaks) + peak_locations_list.append(peak_locations) + + # something relatively easy, only 15 units + with open("all_recordings.pickle", "wb") as handle: + pickle.dump((recordings_list, peaks_list, peak_locations_list), handle, protocol=pickle.HIGHEST_PROTOCOL) + else: + # if False: + # TODO: need to align spatial bin calculation between estimate motion and + # estimate session methods so they are more easily interoperable. OR + # just take spatial bin centers from interpoalte! + recordings_list_new = [] + peaks_list = [] + peak_locations_list = [] + motion_info_list = [] + from spikeinterface.preprocessing.motion import correct_motion + + for i in range(len(recordings_list)): + new_recording, motion_info = correct_motion( + recordings_list[i], + output_motion_info=True, + estimate_motion_kwargs={ + "rigid": False, + "win_shape": "gaussian", + "win_step_um": 50, + "win_margin_um": 0, + }, + ) + recordings_list_new.append(new_recording) + motion_info_list.append(motion_info) + recordings_list = recordings_list_new + + with open("all_recordings_motion.pickle", "wb") as handle: + pickle.dump((recordings_list, motion_info_list), handle, protocol=pickle.HIGHEST_PROTOCOL) + +if MOTION: + with open("all_recordings_motion.pickle", "rb") as handle: + recordings_list, motion_info_list = pickle.load(handle) +else: + with open("all_recordings.pickle", "rb") as handle: + recordings_list, peaks_list, peak_locations_list = pickle.load(handle) + +# TODO: need docs to be super clear from estimate from existing motion, +# as will use motion correction nonrigid bins even if it is suboptimal. + +estimate_histogram_kwargs = { + "bin_um": BIN_UM, + "method": "first_eigenvector", # CHANGE NAME!! # TODO: double check scaling + "chunked_bin_size_s": "estimate", + "log_scale": True, + "depth_smooth_um": 10, + "histogram_type": "activity_1d", # "y_only", "locations_2d", "activity_2d"" TOOD: better names! 
+}
+compute_alignment_kwargs = {
+    "num_shifts_block": None,  # TODO: can be in um so comparable with window kwargs.
+    "interpolate": False,
+    "interp_factor": 10,
+    "kriging_sigma": 1,
+    "kriging_p": 2,
+    "kriging_d": 2,
+    "smoothing_sigma_bin": False,  # 0.5,
+    "smoothing_sigma_window": False,  # 0.5,
+    "akima_interp_nonrigid": False,
+}
+non_rigid_window_kwargs = {
+    "rigid": False,
+    "win_shape": "gaussian",
+    "win_step_um": 250,
+    "win_scale_um": 250,
+    "win_margin_um": None,
+    "zero_threshold": None,
+}
+
+if MOTION:
+    corrected_recordings_list, extra_info = session_alignment.align_sessions_after_motion_correction(
+        recordings_list,
+        motion_info_list,
+        align_sessions_kwargs={
+            "alignment_order": "to_middle",
+            "estimate_histogram_kwargs": estimate_histogram_kwargs,
+            "compute_alignment_kwargs": compute_alignment_kwargs,
+            "non_rigid_window_kwargs": non_rigid_window_kwargs,
+        },
+    )
+    peaks_list = [info["peaks"] for info in motion_info_list]
+    peak_locations_list = [info["peak_locations"] for info in motion_info_list]
+else:
+    corrected_recordings_list, extra_info = session_alignment.align_sessions(
+        recordings_list,
+        peaks_list,
+        peak_locations_list,
+        alignment_order="to_session_1",
+        estimate_histogram_kwargs=estimate_histogram_kwargs,
+        compute_alignment_kwargs=compute_alignment_kwargs,
+    )
+
+plotting_session_alignment.SessionAlignmentWidget(
+    recordings_list,
+    peaks_list,
+    peak_locations_list,
+    extra_info["session_histogram_list"],
+    **extra_info["corrected"],
+    spatial_bin_centers=extra_info["spatial_bin_centers"],
+    drift_raster_map_kwargs={"clim": (-250, 0)},  # TODO: option to fix this across recordings.
+)
+
+plt.show()
+
+# TODO: estimate chunk size Hz needs to be scaled for time? is it not being done correctly?
+
+if False:
+    # No, even two sessions is a mess
+    # TODO: working assumptions, maybe after rigid, make a template for nonrigid alignment
+    # as at the moment all nonrigid to each other is a mess
+ A = extra_info["histogram_info_list"][2]["chunked_histograms"] + + mean_ = alignment_utils.get_chunked_hist_mean(A) + median_ = alignment_utils.get_chunked_hist_median(A) + supremum_ = alignment_utils.get_chunked_hist_supremum(A) + poisson_ = alignment_utils.get_chunked_hist_poisson_estimate(A) + eigenvector_ = alignment_utils.get_chunked_hist_eigenvector(A) + + plt.plot(mean_) + plt.plot(median_) + plt.plot(supremum_) + plt.plot(poisson_) + plt.plot(eigenvector_) + plt.legend(["mean", "median", "supremum", "poisson", "eigenvector"]) + plt.show() diff --git a/debugging/playing.py b/debugging/playing.py new file mode 100644 index 0000000000..dba45ad5b9 --- /dev/null +++ b/debugging/playing.py @@ -0,0 +1,127 @@ +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks + +from spikeinterface.preprocessing.inter_session_alignment import ( + session_alignment, + plotting_session_alignment, +) +import matplotlib.pyplot as plt + +import spikeinterface.full as si +import numpy as np + + +si.set_global_job_kwargs(n_jobs=10) + + +if __name__ == '__main__': + + # -------------------------------------------------------------------------------------- + # Load / generate some recordings + # -------------------------------------------------------------------------------------- + + recordings_list, _ = generate_session_displacement_recordings( + num_units=20, + recording_durations=[400, 400, 400], + recording_shifts=((0, 0), (0, 200), (0, -125)), + non_rigid_gradient=0.005, + seed=52, + ) + if False: + import numpy as np + + + recordings_list = [ + si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-38-28_VR1.zarr\M25_D18_2024-11-05_12-38-28_VR1.zarr"), + si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-08-47_OF1.zarr\M25_D18_2024-11-05_12-08-47_OF1.zarr"), + ] + + recordings_list = [si.astype(rec, np.float32) for rec in recordings_list] + recordings_list = [si.bandpass_filter(rec) for rec in recordings_list] + recordings_list = [si.common_reference(rec, operator="median") for rec in recordings_list] + + # -------------------------------------------------------------------------------------- + # Compute the peaks / locations with your favourite method + # -------------------------------------------------------------------------------------- + # Note if you did motion correction the peaks are on the motion object. + # There is a function 'session_alignment.align_sessions_after_motion_correction() + # you can use instead of the below. 
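+    # Rough sketch of that motion-correction route (see _test_session_alignment.py
+    # for a fuller example). Here `motion_info_list` is hypothetical: it would be
+    # the per-session output of `correct_motion(..., output_motion_info=True)`, and
+    # it already carries the peaks and peak locations, so they would not need to be
+    # recomputed below. The remaining align_sessions_kwargs are assumed to fall
+    # back to their defaults.
+    if False:
+        corrected_recordings_list, extra_info = session_alignment.align_sessions_after_motion_correction(
+            recordings_list,
+            motion_info_list,
+            align_sessions_kwargs={"alignment_order": "to_middle"},
+        )
+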
+ + if False: + peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + + np.save("peaks_1.npy", peaks_list[0]) + np.save("peaks_2.npy", peaks_list[1]) + np.save("peaks_3.npy", peaks_list[2]) + np.save("peak_locs_1.npy", peak_locations_list[0]) + np.save("peak_locs_2.npy", peak_locations_list[1]) + np.save("peak_locs_3.npy", peak_locations_list[2]) + + # if False: + peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")] + peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")] + + # -------------------------------------------------------------------------------------- + # Do the estimation + # -------------------------------------------------------------------------------------- + # For each session, an 'activity histogram' is generated. This can be `entire_session` + # or the session can be chunked into segments and some summary generated taken over then. + # This might be useful if periods of the recording have weird kinetics or noise. + # See `session_alignment.py` for docs on these settings. + + non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() + non_rigid_window_kwargs["rigid_mode"] = "nonrigid" + non_rigid_window_kwargs["win_shape"] = "rect" + non_rigid_window_kwargs["win_step_um"] = 200.0 + non_rigid_window_kwargs["win_scale_um"] = 400.0 + + estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() + estimate_histogram_kwargs["method"] = "chunked_median" + estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly + estimate_histogram_kwargs["bin_um"] = 0.5 + estimate_histogram_kwargs["log_scale"] = True + estimate_histogram_kwargs["weight_with_amplitude"] = False + + compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() + compute_alignment_kwargs["num_shifts_block"] = 300 + + corrected_recordings_list, extra_info = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + alignment_order="to_session_2", # "to_session_X" or "to_middle" + non_rigid_window_kwargs=non_rigid_window_kwargs, + estimate_histogram_kwargs=estimate_histogram_kwargs, + ) + + # TODO: nonlinear is not working well 'to middle', investigate + # TODO: also finalise the estimation of bin number of nonrigid. 
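+    # Quick sanity check of the alignment inputs (keys as used elsewhere in this
+    # script; a sketch, not a stable API): overlay the per-session activity
+    # histograms that the shifts are estimated from. The corrected versions are
+    # plotted by the widget further below.
+    if False:
+        centers = extra_info["spatial_bin_centers"]
+        for session_histogram in extra_info["session_histogram_list"]:
+            plt.plot(centers, session_histogram)
+        plt.xlabel("depth (um)")
+        plt.ylabel("activity")
+        plt.show()
+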
+ + if False: + plt.plot(extra_info["histogram_info_list"][0]["chunked_histograms"].T, color="black") + + M = extra_info["session_histogram_list"][0] + S = extra_info["histogram_info_list"][0]["session_histogram_variation"] + + plt.plot(M, color="red") + plt.plot(M + S, color="green") + plt.plot(M - S, color="green") + + plt.show() + + plotting_session_alignment.SessionAlignmentWidget( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10} + ) + + plt.show() diff --git a/debugging/playing2.py b/debugging/playing2.py new file mode 100644 index 0000000000..63710586a8 --- /dev/null +++ b/debugging/playing2.py @@ -0,0 +1,220 @@ +import numpy as np +import matplotlib.pyplot as plt +import scipy + + +def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: + abs_shift = np.abs(shift) + pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) + padded_hist = np.pad(array, pad_tuple, mode="constant") + cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] + return cut_padded_array + + +# Load and normalize signals +signal1 = np.load(r"C:\Users\Joe\work\git-repos\forks\spikeinterface\debugging\signal1_1.npy") +signal2 = np.load(r"C:\Users\Joe\work\git-repos\forks\spikeinterface\debugging\signal2_1.npy") + + +def cross_correlate(sig1, sig2, thr= None): + xcorr = np.correlate(sig1, sig2, mode="full") + + n = sig1.size + low_cut_idx = np.arange(0, n - thr) # double check + high_cut_idx = np.arange(n + thr, 2 * n - 1) + + xcorr[low_cut_idx] = 0 + xcorr[high_cut_idx] = 0 + + if np.max(xcorr) < 0.01: + shift = 0 + else: + shift = np.argmax(xcorr) - xcorr.size // 2 + + return shift + +def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True): + """ + """ + best_correlation = 0 + best_displacements = np.zeros_like(signa11_blanked) + + # TODO: use kriging interp + + xcorr = [] + + for scale in np.linspace(0.85, 1.15, 10): + + nonzero = np.where(signa11_blanked > 0)[0] + if not np.any(nonzero): + continue + + midpoint = nonzero[0] + np.ptp(nonzero) / 2 + x_scale = (x - midpoint) * scale + midpoint + + interp_f = scipy.interpolate.interp1d(x_scale, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging + + scaled_func = interp_f(x) + + # plt.plot(signa11_blanked) + # plt.plot(scaled_func) + # plt.show() + + # breakpoint() + + for sh in np.arange(-thr, thr): # TODO: we are off by one here + + shift_signal1_blanked = shift_array_fill_zeros(scaled_func, sh) + + x_shift = x_scale - sh # TODO: rename + + # is this pull back? + # interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... 
or Kriging + + # scaled_func = interp_f(x_shift) + + corr_value = np.correlate( + shift_signal1_blanked - np.mean(shift_signal1_blanked), + signal2_blanked - np.mean(signal2_blanked), + ) / signa11_blanked.size + + if corr_value > best_correlation: + best_displacements = x_shift + best_correlation = corr_value + + if False and np.abs(sh) == 1: + print(corr_value) + + plt.plot(shift_signal1_blanked) + plt.plot(signal2_blanked) + plt.show() + # plt.draw() # Draw the updated figure + # plt.pause(0.1) # Pause for 0.5 seconds before updating + # plt.clf() + + # breakpoint() + + + # xcorr.append(np.max(np.r_[xcorr_scale])) + + if False: + xcorr = np.r_[xcorr] + # shift = np.argmax(xcorr) - thr + + print("MAX", np.max(xcorr)) + + if np.max(xcorr) < 0.0001: + shift = 0 + else: + shift = np.argmax(xcorr) - thr + + print("output shift", shift) + + return best_displacements + +# plt.plot(signal1) +# plt.plot(signal2) + +def get_shifts(signal1, signal2, windows, plot=True): + + import matplotlib.pyplot as plt + + signa11_blanked = signal1.copy() + signal2_blanked = signal2.copy() + + best_displacements = np.zeros_like(signal1) + + if (first_idx := windows[0][0]) != 0: + print("first idx", first_idx) + signa11_blanked[:first_idx] = 0 + signal2_blanked[:first_idx] = 0 + + if (last_idx := windows[-1][-1]) != signal1.size - 1: # double check + print("last idx", last_idx) + signa11_blanked[last_idx:] = 0 + signal2_blanked[last_idx:] = 0 + + segment_shifts = np.empty(len(windows)) + + + x = np.arange(signa11_blanked.size) + x_orig = x.copy() + + for round in range(len(windows)): + + #if round == 0: + # shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! + #else: + displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False) + + + + # breakpoint() + + interpf = scipy.interpolate.interp1d(displacements, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme + signa11_blanked = interpf(x) + + + + # cum_shifts.append(shift) + # print("shift", shift) + + # shift the signal1, or use indexing + +# signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors... + + # if plot: + # print("round", round) + # plt.plot(signa11_blanked) + # plt.plot(signal2_blanked) + # plt.show() + + window_corrs = np.empty(len(windows)) + for i, idx in enumerate(windows): + window_corrs[i] = np.correlate( + signa11_blanked[idx] - np.mean(signa11_blanked[idx]), + signal2_blanked[idx] - np.mean(signal2_blanked[idx]), + ) / signa11_blanked[idx].size + + max_window = np.argmax(window_corrs) # TODO: cutoff! + + if False: + small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2) + signa11_blanked = shift_array_fill_zeros(signa11_blanked, small_shift) + segment_shifts[max_window] = np.sum(cum_shifts) + small_shift + + best_displacements[windows[max_window]] = displacements[windows[max_window]] + + x = displacements + + signa11_blanked[windows[max_window]] = 0 + signal2_blanked[windows[max_window]] = 0 + + # TODO: need to carry over displacements! 
+ + print(best_displacements) + interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme + final = interpf(x_orig) + + plt.plot(final) + plt.plot(signal2) + plt.show() + + return segment_shifts + + +num_windows = 5 + +windows = np.arange(signal1.size) + +windows = np.array_split(windows, num_windows) + +shifts = get_shifts(signal1, signal2, windows) + +if False: + + shifts[0::2] = np.array(shifts1) # TODO: MOVE + shifts[1::2] = np.array(shifts2) + + breakpoint() + print("done") diff --git a/debugging/playing_inter-session-alignment.py b/debugging/playing_inter-session-alignment.py new file mode 100644 index 0000000000..44c3e4dad4 --- /dev/null +++ b/debugging/playing_inter-session-alignment.py @@ -0,0 +1,108 @@ +from pathlib import Path +import spikeinterface.full as si +import numpy as np + +# base_path = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\inter-session-alignment\test_motion_project_short\derivatives\1119617") +# sessions = [ +# "1119617_LSE1_shank12_g0", +# "1119617_posttest1_shank12_g0", +# "1119617_pretest1_shank12_g0", +# ] + +base_path = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\inter-session-alignment\test_motion_project\derivatives\sub-013_id-1121381") +sessions = [ + "ses-006_date-20231223_type-lse2", + "ses-003_date-20231221_type-pretest", + "ses-007_date-20231223_type-posttest2", +] + +recordings_list = [] +peaks_list = [] +peak_locations_list = [] + +for ses in sessions: + print(ses) + + ses_path = base_path / ses + + rec = si.load_extractor(ses_path / "preprocessing" / "si_recording") + rec = si.astype(rec, np.float32) + + recordings_list.append(rec) + peaks_list.append(np.load(ses_path / "motion_npy_files" / "peaks.npy" )) + peak_locations_list.append(np.load(ses_path / "motion_npy_files" / "peak_locations.npy")) + + +estimate_histogram_kwargs = { + "bin_um": 5, + "method": "chunked_median", # TODO: double check scaling + "chunked_bin_size_s": "estimate", + "log_scale": False, # TODO: this will mess up time chunk estimation? not currently but definately test this carefully. + "depth_smooth_um": 5, +} +compute_alignment_kwargs = { + "num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs. 
+ "interpolate": False, + "interp_factor": 10, + "kriging_sigma": 1, + "kriging_p": 2, + "kriging_d": 2, + "smoothing_sigma_bin": False, # 0.5, + "smoothing_sigma_window": False, # 0.5, + "akima_interp_nonrigid": False, +} +non_rigid_window_kwargs = { + "rigid": True, + "win_shape": "gaussian", + "win_step_um": 400, + "win_scale_um": 400, + "win_margin_um": None, + "zero_threshold": None, +} + +from spikeinterface.preprocessing.inter_session_alignment import ( + session_alignment, + plotting_session_alignment, + alignment_utils +) +import matplotlib.pyplot as plt + +# TODO: add some print statements for progress +corrected_recordings_list, extra_info = session_alignment.align_sessions( + recordings_list, + peaks_list, + peak_locations_list, + alignment_order="to_session_1", + estimate_histogram_kwargs=estimate_histogram_kwargs, + compute_alignment_kwargs=compute_alignment_kwargs, + non_rigid_window_kwargs=non_rigid_window_kwargs, +) + +plotting_session_alignment.SessionAlignmentWidget( + recordings_list, + peaks_list, + peak_locations_list, + extra_info["session_histogram_list"], + **extra_info["corrected"], + spatial_bin_centers=extra_info["spatial_bin_centers"], + drift_raster_map_kwargs={"clim":(-250, 0), "scatter_decimate": 10} # TODO: option to fix this across recordings. +) + +plt.show() + +A = extra_info["histogram_info_list"][0]["chunked_histograms"] + +mean_ = alignment_utils.get_chunked_hist_mean(A) +median_ = alignment_utils.get_chunked_hist_median(A) +supremum_ = alignment_utils.get_chunked_hist_supremum(A) +poisson_ = alignment_utils.get_chunked_hist_poisson_estimate(A) +eigenvector_ = alignment_utils.get_chunked_hist_eigenvector(A) + +plt.plot(extra_info["spatial_bin_centers"], A.T, color="k") +plt.plot(extra_info["spatial_bin_centers"], mean_) +plt.plot(extra_info["spatial_bin_centers"], median_) +plt.plot(extra_info["spatial_bin_centers"], supremum_) +plt.plot(extra_info["spatial_bin_centers"], poisson_) +plt.plot(extra_info["spatial_bin_centers"], eigenvector_) +plt.legend(["mean", "median", "supremum", "poisson", "eigenvector"]) +plt.show() diff --git a/debugging/simulating/playing_with_estimate_time.py b/debugging/simulating/playing_with_estimate_time.py new file mode 100644 index 0000000000..234a280534 --- /dev/null +++ b/debugging/simulating/playing_with_estimate_time.py @@ -0,0 +1,31 @@ +import numpy as np + +lambda_hat_s = 2 +range_percent = 0.1 + +confidence_z = 1.645 # TODO: check, based on 90% confidence + +e = lambda_hat_s * range_percent + +n = lambda_hat_s / (e / confidence_z)**2 + + +MC = 10000 + +sim_data = np.empty(MC) + +for i in range(MC): + + # Dont do this, model a poisson process + draws = np.random.exponential(1/lambda_hat_s, size=10000) # way too many, calculate properly + + in_time_range = np.cumsum(draws) < n + assert not np.all(in_time_range) / n, "need to increase size" + + count = np.sum(in_time_range) / n + + sim_data[i] = count + +in_range = np.logical_or(sim_data < lambda_hat_s - e, sim_data > lambda_hat_s + e) + +print(f"confidence : {1 - np.mean(in_range)}") # this is compeltely wrong diff --git a/debugging/simulating/sim_histogram_alignment.py b/debugging/simulating/sim_histogram_alignment.py new file mode 100644 index 0000000000..e83e986d29 --- /dev/null +++ b/debugging/simulating/sim_histogram_alignment.py @@ -0,0 +1,124 @@ +import numpy as np +import matplotlib.pyplot as plt + +fs = 1 +ts = 1 +num_chan = 384 +linear_shift_chan = -75 + +num_units = 50 +unit_means = np.random.random_integers(0, 384, num_units) +unit_stds = 
np.random.random_integers(1, 15, num_units) +unit_firing_rates = np.random.random(num_units) +recording_time_s = 1000 + +unit_spike_times_1 = [] +unit_spike_times_2 = [] +for i in range(num_units): + + spikes = np.random.exponential(1 / unit_firing_rates[i], 1000) # figure this out + spike_times = np.cumsum(spikes) + spike_times = spike_times[spike_times < recording_time_s] + unit_spike_times_1.append(spike_times) + + spikes = np.random.exponential(1 / unit_firing_rates[i], 1000) # figure this out + spike_times = np.cumsum(spikes) + spike_times = spike_times[spike_times < recording_time_s] + unit_spike_times_2.append(spike_times) + +# TODO: do the above twice? + +# Sample the location of each spike +# Do twice, once with a shift (linear, ignore nonlinear for now) +unit_spike_locations_1 = [] +unit_spike_locations_2 = [] +for i in range(num_units): + + spike_locs_1 = np.random.normal(unit_means[i], unit_stds[i], size=len(unit_spike_times_1[i])) + + spike_locs_2 = np.random.normal(unit_means[i] + linear_shift_chan, unit_stds[i], size=len(unit_spike_times_2[i])) + + unit_spike_locations_1.append(spike_locs_1) + unit_spike_locations_2.append(spike_locs_2) + +# if False: +for i in range(num_units): + plt.scatter(unit_spike_times_1[i], unit_spike_locations_1[i]) +plt.ylim(0, 384) +plt.show() + +for i in range(num_units): + plt.scatter(unit_spike_times_2[i], unit_spike_locations_2[i]) +plt.ylim(0, 384) +plt.show() + +all_hist_1 = [] +edges_1 = [] +all_hist_2 = [] +edges_2 = [] + +locs_1 = np.hstack(unit_spike_locations_1) +locs_1 = locs_1[np.where(np.logical_and(locs_1 >= 0, locs_1 <= 384))] +locs_2 = np.hstack(unit_spike_locations_2) +locs_2 = locs_2[np.where(np.logical_and(locs_2 >= 0, locs_2 <= 384))] +for i in range(1, 385): + + bins = np.linspace(0, 1, i + 1) * 384 + hist_1, bin_edges_1 = np.histogram(locs_1, bins=bins) + hist_2, bin_edges_2 = np.histogram(locs_2, bins=bins) + + all_hist_1.append(hist_1) + bin_edges_1 = (bin_edges_1[1:] + bin_edges_1[:-1]) / 2 + edges_1.append(bin_edges_1) + + bin_edges_2 = (bin_edges_2[1:] + bin_edges_2[:-1]) / 2 + all_hist_2.append(hist_2) + edges_2.append(bin_edges_2) + +estimated_shift = np.zeros(384) +for i in range(384): + + xcorr = np.correlate(all_hist_1[i], all_hist_2[i], mode="same") + xmax = np.argmax(xcorr) + + half_bin = len(all_hist_1[i]) / 2 + estimated_shift[i] = half_bin - xmax # TODO: check this lol, do it better + + if False: + if i in (100, 200, 300): + plt.bar(edges_1[i], all_hist_1[i], width=384 / (i + 1)) + plt.xlim(0, 384) # handle this for correlation + plt.show() + assert np.array_equal(edges_1[i], edges_2[i]) + plt.bar(edges_2[i], all_hist_2[i], width=384 / (i + 1), color="orange") + plt.xlim(0, 384) # handle this for correlation + plt.show() + + plt.plot(xcorr) + plt.vlines(half_bin, 0, np.max(xcorr)) + plt.show() + + +plt.plot(np.arange(384), estimated_shift) +plt.hlines(linear_shift_chan, 0, 384) +plt.show() + +from scipy.stats import norm + +# TODO: can check that the prediction +fake_array = np.zeros(384) +for i in range(384): + H_i = 0 + for j in range(num_units): + prob = norm.pdf(i, loc=unit_means[j], scale=unit_stds[j]) + H_i += prob * recording_time_s * unit_firing_rates[j] + fake_array[i] = H_i + +plt.plot(fake_array) +plt.plot(edges_1[-1], all_hist_1[-1]) +plt.show() + +# create a histogram +# smooth all possible smooithings + +# Compute cross-correlation. 
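+# Note on the shift estimate above (a sketch; the midpoint/sign is flagged with a
+# TODO): with mode="same" the zero-lag index depends on how numpy centres the
+# output, which is easy to get off by one. Using mode="full" sidesteps this,
+# because for two length-n histograms zero lag is always at index n - 1.
+def _shift_in_channels(hist_1, hist_2, total_span=384):
+    n = hist_1.size
+    xcorr = np.correlate(hist_1, hist_2, mode="full")  # 2 * n - 1 lags
+    lag_bins = np.argmax(xcorr) - (n - 1)
+    return lag_bins * (total_span / n)  # convert bins to channel units
+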
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0316b3bab1..9fa9ee457d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -96,6 +96,7 @@ def generate_sorting( add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, + extra_outputs=False, seed=None, ): """ @@ -136,10 +137,14 @@ def generate_sorting( num_segments = len(durations) unit_ids = np.arange(num_units) + extra_outputs_dict = { + "firing_rates": [], + } + spikes = [] for segment_index in range(num_segments): num_samples = int(sampling_frequency * durations[segment_index]) - samples, labels = synthesize_poisson_spike_vector( + samples, labels, firing_rates_array = synthesize_poisson_spike_vector( num_units=num_units, sampling_frequency=sampling_frequency, duration=durations[segment_index], @@ -173,12 +178,17 @@ def generate_sorting( ) spikes.append(spikes_on_borders) + extra_outputs_dict["firing_rates"].append(firing_rates_array) + spikes = np.concatenate(spikes) spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) - return sorting + if extra_outputs: + return sorting, extra_outputs_dict + else: + return sorting def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): @@ -786,7 +796,7 @@ def synthesize_poisson_spike_vector( unit_indices = unit_indices[sort_indices] spike_frames = spike_frames[sort_indices] - return spike_frames, unit_indices + return spike_frames, unit_indices, firing_rates def synthesize_random_firings( @@ -2434,12 +2444,19 @@ def generate_ground_truth_recording( parent_recording=noise_rec, upsample_vector=upsample_vector, ) - recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) - recording.set_channel_gains(1.0) - recording.set_channel_offsets(0.0) - + setup_inject_templates_recording(recording, probe) recording.name = "GroundTruthRecording" sorting.name = "GroundTruthSorting" return recording, sorting + + +def setup_inject_templates_recording(recording: BaseRecording, probe: Probe) -> None: + """ + Convenience function to modify a generated + recording in-place with annotation and probe details + """ + recording.annotate(is_filtered=True) + recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 53c2445c77..9f6b4d88e0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -503,7 +503,7 @@ def run_node_pipeline( Here a "spike" is a spike with any a label so already sorted. The main idea is to have a graph of nodes. - Every node is doing a computaion of some peaks and related traces. + Every node is doing a computation of some peaks and related traces. The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) Every node can have one or several output that can be directed to other nodes (aka nodes have parents). @@ -542,7 +542,7 @@ def run_node_pipeline( Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. recording_slices : None | list[tuple] - Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + Optionally give a list of slices to run the pipeline only on some chunks of the recording. 
It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 6ff8adadd2..9e1a595640 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -8,6 +8,7 @@ """ +from __future__ import annotations import numpy as np from probeinterface import generate_multi_columns_probe @@ -21,6 +22,7 @@ ) from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording from .noise_tools import generate_noise +from probeinterface import Probe # this should be moved in probeinterface but later @@ -181,7 +183,7 @@ def generate_displacement_vector( duration : float Duration of the displacement vector in seconds unit_locations : np.array - The unit location with shape (num_units, 3) + The unit location with shape (num_units, 2) displacement_sampling_frequency : float, default: 5. The sampling frequency of the displacement vector drift_start_um : list of float, default: [0, 20.] @@ -240,22 +242,64 @@ def generate_displacement_vector( if non_rigid_gradient is None: displacement_unit_factor[:, m] = 1 else: - gradient_direction = drift_stop_um - drift_start_um - gradient_direction /= np.linalg.norm(gradient_direction) - - proj = np.dot(unit_locations, gradient_direction).squeeze() - factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) - if non_rigid_gradient < 0: - # reverse - factors = 1 - factors - f = np.abs(non_rigid_gradient) - displacement_unit_factor[:, m] = factors * (1 - f) + f + displacement_unit_factor[:, m] = calculate_displacement_unit_factor( + non_rigid_gradient, unit_locations, drift_start_um, drift_stop_um + ) displacement_vectors = np.concatenate(displacement_vectors, axis=2) return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps +def calculate_displacement_unit_factor( + non_rigid_gradient: float, unit_locations: np.array, drift_start_um: np.array, drift_stop_um: np.array +) -> np.array: + """ + In the case of introducing non-rigid drift, a set of scaling + factors (one per unit) is generated for scaling the displacement + as a function of unit position. + + The projections of the gradient vector (x, y) + and unit locations (x, y) are normalised to range between + 0 and 1 (i.e. based on relative location to the gradient). + These factors are scaled by `non_rigid_gradient`. + + Parameters + ---------- + + non_rigid_gradient : float + A number in the range [0, 1] by which to scale the scaling factors + that are based on unit location. This sets the weighting given to the factors + based on unit locations. When 1, the factors will all equal 1 (no effect), + when 0, the scaling factor based on unit location will be used directly. + unit_locations : np.array + The unit location with shape (num_units, 2) + drift_start_um : np.array + The start boundary of the motion in the x and y direction. + drift_stop_um : np.array + The stop boundary of the motion in the x and y direction. + + Returns + ------- + displacement_unit_factor : np.array + An array of scaling factors (one per unit) by which + to scale the displacement. 
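+        For example, with `non_rigid_gradient=0.2` the unit with the smallest
+        projection onto the drift direction is scaled by 0 * (1 - 0.2) + 0.2 = 0.2,
+        while the unit with the largest projection is scaled by 1 * (1 - 0.2) + 0.2 = 1.0.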
+ """ + gradient_direction = drift_stop_um - drift_start_um + gradient_direction /= np.linalg.norm(gradient_direction) + + proj = np.dot(unit_locations, gradient_direction).squeeze() + factors = (proj - np.min(proj)) / (np.max(proj) - np.min(proj)) + + if non_rigid_gradient < 0: # reverse + factors = 1 - factors + + f = np.abs(non_rigid_gradient) + displacement_unit_factor = factors * (1 - f) + f + + return displacement_unit_factor + + def generate_drifting_recording( num_units=250, duration=600.0, @@ -351,12 +395,9 @@ def generate_drifting_recording( This can be helpfull for motion benchmark. """ # probe - if generate_probe_kwargs is None: - generate_probe_kwargs = _toy_probes[probe_name] - probe = generate_multi_columns_probe(**generate_probe_kwargs) - num_channels = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(num_channels)) + probe = generate_probe(generate_probe_kwargs, probe_name) channel_locations = probe.contact_positions + # import matplotlib.pyplot as plt # import probeinterface.plotting # fig, ax = plt.subplots() @@ -384,9 +425,7 @@ def generate_drifting_recording( unit_displacements[:, :, direction] += m # unit_params need to be fixed before the displacement steps - generate_templates_kwargs = generate_templates_kwargs.copy() - unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) - generate_templates_kwargs["unit_params"] = unit_params + generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) # generate templates templates_array = generate_templates( @@ -478,3 +517,50 @@ def generate_drifting_recording( return static_recording, drifting_recording, sorting, extra_infos else: return static_recording, drifting_recording, sorting + + +def generate_probe(generate_probe_kwargs: dict, probe_name: str | None = None) -> Probe: + """ + Generate a probe for use in certain ground-truth recordings. + + Parameters + ---------- + + generate_probe_kwargs : dict + The kwargs to pass to `generate_multi_columns_probe()` + probe_name : str | None + The probe type if generate_probe_kwargs is None. + """ + if generate_probe_kwargs is None: + assert probe_name is not None, "`probe_name` must be set if `generate_probe_kwargs` is `None`." + generate_probe_kwargs = _toy_probes[probe_name] + probe = generate_multi_columns_probe(**generate_probe_kwargs) + num_channels = probe.get_contact_count() + probe.set_device_channel_indices(np.arange(num_channels)) + + return probe + + +def fix_generate_templates_kwargs(generate_templates_kwargs: dict, num_units: int, seed: int) -> dict: + """ + Fix the generate_template_kwargs such that the same units are created + across calls to `generate_template`. We must explicitly pre-set + the parameters for each unit, done in `_ensure_unit_params()`. + + Parameters + ---------- + + generate_templates_kwargs : dict + These kwargs will have the "unit_params" entry edited such that the + parameters are explicitly set for each unit to create (rather than + generated randomly on the fly). + num_units : int + Number of units to fix the kwargs for + seed : int + Random seed. 
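+
+    The returned kwargs are intended to be passed unchanged to every
+    `generate_templates()` call, so that each generated session contains the
+    same units even when `seed` is None.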
+ """ + generate_templates_kwargs = generate_templates_kwargs.copy() + unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) + generate_templates_kwargs["unit_params"] = unit_params + + return generate_templates_kwargs diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py new file mode 100644 index 0000000000..39d3037d68 --- /dev/null +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -0,0 +1,500 @@ +import copy + +from spikeinterface.generation.drifting_generator import ( + generate_probe, + fix_generate_templates_kwargs, + calculate_displacement_unit_factor, +) +from spikeinterface.core.generate import ( + generate_unit_locations, + generate_sorting, + generate_templates, +) +import numpy as np +from spikeinterface.generation.noise_tools import generate_noise +from spikeinterface.core.generate import setup_inject_templates_recording, _ensure_firing_rates +from spikeinterface.core import InjectTemplatesRecording + + +def generate_session_displacement_recordings( + num_units=250, + recording_durations=(10, 10, 10), + recording_shifts=((0, 0), (0, 25), (0, 50)), + non_rigid_gradient=None, + recording_amplitude_scalings=None, + shift_units_outside_probe=False, + sampling_frequency=30000.0, + probe_name="Neuropixel-128", + generate_probe_kwargs=None, + generate_unit_locations_kwargs=dict( + margin_um=20.0, + minimum_z=5.0, + maximum_z=45.0, + minimum_distance=18.0, + max_iteration=100, + distance_strict=False, + ), + generate_templates_kwargs=dict( + ms_before=1.5, + ms_after=3.0, + mode="ellipsoid", + unit_params=dict( + alpha=(150.0, 500.0), + spatial_decay=(10, 45), + ), + ), + generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0), + generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0), + extra_outputs=False, + seed=None, +): + """ + Generate a set of recordings simulating probe drift across recording + sessions. + + Rigid drift can be added in the (x, y) direction in `recording_shifts`. + These drifts can be made non-rigid (scaled dependent on the unit location) + with the `non_rigid_gradient` parameter. Amplitude of units can be scaled + (e.g. template signal removed by scaling with zero) by specifying scaling + factors in `recording_amplitude_scalings`. + + Parameters + ---------- + + num_units : int + The number of units in the generated recordings. + recording_durations : list + An array of length (num_recordings,) specifying the + duration that each created recording should be. + recording_shifts : list + An array of length (num_recordings,) in which each element + is a 2-element array specifying the (x, y) shift for the recording. + Typically, the first recording will have shift (0, 0) so all further + recordings are shifted relative to it. e.g. to create two recordings, + the second shifted by 50 um in the x-direction and 250 um in the y + direction : ((0, 0), (50, 250)). + non_rigid_gradient : float + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + recording_amplitude_scalings : dict + A dict with keys: + "method" - order by which to apply the scalings. 
+ "by_passed_order" - scalings are applied to the unit templates + in order passed + "by_firing_rate" - scalings are applied to the units in order of + maximum to minimum firing rate + "by_amplitude_and_firing_rate" - scalings are applied to the units + in order of amplitude * firing_rate (maximum to minimum) + "scalings" - a list of numpy arrays, one for each recording, with + each entry an array of length num_units holding the unit scalings. + e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)). + shift_units_outside_probe : bool + By default (`False`), when units are shifted across sessions, new units are + not introduced into the recording (e.g. the region in which units + have been shifted out of is left at baseline level). In reality, + when the probe shifts new units from outside the original recorded + region are shifted into the recording. When `True`, new units + are shifted into the generated recording. + generate_sorting_kwargs : dict + Only `firing_rates` and `refractory_period_ms` are expected if passed. + + All other parameters are used as in from `generate_drifting_recording()`. + + Returns + ------- + output_recordings : list + A list of recordings with units shifted (i.e. replicated probe shift). + output_sortings : list + A list of corresponding sorting objects. + extra_outputs_dict (options) : dict + When `extra_outputs` is `True`, a dict containing variables used + in the generation process. + "unit_locations" : A list (length num records) of shifted unit locations + "templates_array_moved" : list[np.array] + A list (length num records) of (num_units, num_samples, num_channels) + arrays of templates that have been shifted. + + Notes + ----- + It is important to consider what unit properties are maintained + across the session. Here, all `generate_template_kwargs` are fixed + across sessions, to be sure the unit properties do not change. + The firing rates passed to `generate_sorting` for each unit are + also fixed across sessions. When a seed is set, the exact spike times + will also be fixed across recordings. otherwise, when seed is `None` + the actual spike times will be different across recordings, although + all other unit properties will be maintained (except any location + shifting and template scaling applied). + """ + # temporary fix + generate_unit_locations_kwargs = copy.deepcopy(generate_unit_locations_kwargs) + generate_templates_kwargs = copy.deepcopy(generate_templates_kwargs) + generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs) + generate_noise_kwargs = copy.deepcopy(generate_noise_kwargs) + + _check_generate_session_displacement_arguments( + num_units, recording_durations, recording_shifts, recording_amplitude_scalings + ) + + probe = generate_probe(generate_probe_kwargs, probe_name) + channel_locations = probe.contact_positions + + # Create the starting unit locations (which will be shifted). + unit_locations = generate_unit_locations( + num_units, + channel_locations, + seed=seed, + **generate_unit_locations_kwargs, + ) + + # Fix generate template kwargs, so they are the same for every created recording. + # Also fix unit firing rates across recordings. 
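+    # (Illustration: ranges such as alpha=(150.0, 500.0) are resolved to one value
+    # per unit here, e.g. something like rng.uniform(150.0, 500.0, num_units), and
+    # those explicit per-unit values are what keep unit identities constant across
+    # sessions even when `seed` is None.)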
+ fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) + + fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed) + fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs) + fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates + + if shift_units_outside_probe: + num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = ( + _update_kwargs_for_extended_units( + num_units, + channel_locations, + unit_locations, + generate_unit_locations_kwargs, + generate_templates_kwargs, + generate_sorting_kwargs, + fixed_generate_templates_kwargs, + fixed_generate_sorting_kwargs, + seed, + ) + ) + + # Start looping over parameters, creating recordings shifted + # and scaled as required + extra_outputs_dict = { + "unit_locations": [], + "templates_array_moved": [], + "firing_rates": [], + } + output_recordings = [] + output_sortings = [] + + for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)): + + displacement_vector, displacement_unit_factor = _get_inter_session_displacements( + shift, + non_rigid_gradient, + num_units, + unit_locations, + ) + + # Move the canonical `unit_locations` according to the set (x, y) shifts + unit_locations_moved = unit_locations.copy() + unit_locations_moved[:, :2] += displacement_vector[0, :][np.newaxis, :] * displacement_unit_factor + + # Generate the sorting (e.g. spike times) for the recording + sorting, sorting_extra_outputs = generate_sorting( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=[duration], + **fixed_generate_sorting_kwargs, + extra_outputs=True, + seed=seed, + ) + sorting.set_property("gt_unit_locations", unit_locations_moved) + + # Generate the noise in the recording + noise = generate_noise( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + seed=seed, + **generate_noise_kwargs, + ) + + # Generate the (possibly shifted, scaled) unit templates + templates_array_moved = generate_templates( + channel_locations, + unit_locations_moved, + sampling_frequency=sampling_frequency, + seed=seed, + **fixed_generate_templates_kwargs, + ) + + if recording_amplitude_scalings is not None: + + first_rec_templates = ( + templates_array_moved if rec_idx == 0 else extra_outputs_dict["templates_array_moved"][0] + ) + + _amplitude_scale_templates_in_place( + first_rec_templates, templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx + ) + + # Bring it all together in a `InjectTemplatesRecording` and + # propagate all relevant metadata to the recording. 
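# Standalone sketch (toy values, not part of this loop) of how the rigid (x, y)
# shift above combines with the per-unit factor to move unit locations:
import numpy as np

example_locations = np.array([[10.0, 100.0, 0.0], [20.0, 300.0, 0.0]])  # (num_units, 3)
example_displacement_vector = np.atleast_2d((0, 50))                    # single (x, y) shift
example_unit_factor = np.array([[0.5], [1.0]])                          # per-unit non-rigid scaling

example_moved = example_locations.copy()
example_moved[:, :2] += example_displacement_vector[0, :][np.newaxis, :] * example_unit_factor
# example_moved[:, 1] -> [125., 350.]: the second unit receives the full 50 um shift,
# the first only half of it.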
+ ms_before = fixed_generate_templates_kwargs["ms_before"] + nbefore = int(sampling_frequency * ms_before / 1000.0) + + recording = InjectTemplatesRecording( + sorting=sorting, + templates=templates_array_moved, + nbefore=nbefore, + amplitude_factor=None, + parent_recording=noise, + num_samples=noise.get_num_samples(0), + upsample_vector=None, + check_borders=False, + ) + + setup_inject_templates_recording(recording, probe) + + recording.name = "InterSessionDisplacementRecording" + sorting.name = "InterSessionDisplacementSorting" + + output_recordings.append(recording) + output_sortings.append(sorting) + extra_outputs_dict["unit_locations"].append(unit_locations_moved) + extra_outputs_dict["templates_array_moved"].append(templates_array_moved) + extra_outputs_dict["firing_rates"].append(sorting_extra_outputs["firing_rates"][0]) + + if extra_outputs: + return output_recordings, output_sortings, extra_outputs_dict + else: + return output_recordings, output_sortings + + +def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations): + """ + Get the formatted `displacement_vector` and `displacement_unit_factor` + used to shift the `unit_locations`.. + + Parameters + --------- + shift : np.array | list | tuple + A 2-element array with the shift in the (x, y) direction. + non_rigid_gradient : float + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + num_units : int + Number of units + unit_locations : np.array + (num_units, 3) array of unit locations (x, y, z). + + Returns + ------- + displacement_vector : np.array + A (:, 2) array of (x, y) of displacements + to add to (i.e. move) unit_locations. + e.g. array([[1, 2]]) + displacement_unit_factor : np.array + A (num_units, :) array of scaling values to apply to the + displacement vector in order to add nonrigid shift to + the displacement. Note the same scaling is applied to the + x and y dimension. + """ + displacement_vector = np.atleast_2d(shift) + + if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0): + displacement_unit_factor = np.ones((num_units, 1)) + else: + displacement_unit_factor = calculate_displacement_unit_factor( + non_rigid_gradient, + unit_locations[:, :2], + drift_start_um=np.array([0, 0], dtype=float), + drift_stop_um=np.array(shift, dtype=float), + ) + displacement_unit_factor = displacement_unit_factor[:, np.newaxis] + + return displacement_vector, displacement_unit_factor + + +def _amplitude_scale_templates_in_place( + first_rec_templates, moved_templates, recording_amplitude_scalings, sorting_extra_outputs, rec_idx +): + """ + Scale a set of templates given a set of scaling values. The scaling + values can be applied in the order passed, or instead in order of + the unit firing range (max to min) or unit amplitude * firing rate (max to min). + This will chang the `templates_array` in place. + + Parameters + ---------- + first_rec_templates : np.array + The (num_units, num_samples, num_channels) templates array from the + first recording. Scaling by amplitude scales based on the amplitudes in + the first session. + moved_templates : np.array + A (num_units, num_samples, num_channels) array moved templates to the + current recording, that will be scaled. + recording_amplitude_scalings : dict + see `generate_session_displacement_recordings()`. + sorting_extra_outputs : dict + Extra output of `generate_sorting` holding the firing frequency of all units. + The unit order is assumed to match the templates. 
+ rec_idx : int + The index of the recording for which the templates are being scaled. + + Notes + ----- + This method is used in the context of inter-session displacement. Often, + units may drop out of the recording across sessions. This simulates this by + directly scaling the template (e.g. if scaling by 0, the template is completely + dropped out). The provided scalings can be applied in the order passed, or + in the order of unit firing rate or firing rate * amplitude. The idea is + that it may be desired to remove to downscale the most activate neurons + that contribute most significantly to activity histograms. Similarly, + if amplitude is included in activity histograms the amplitude may + also want to be considered when ordering the units to downscale. + """ + method = recording_amplitude_scalings["method"] + + if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]: + + firing_rates_hz = sorting_extra_outputs["firing_rates"][0] + + if method == "by_amplitude_and_firing_rate": + neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1) + assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here." + score = firing_rates_hz * neg_ampl + else: + score = firing_rates_hz + + order_idx = np.argsort(score) + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis] + + elif method == "by_passed_order": + + ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][:, np.newaxis, np.newaxis] + + else: + raise ValueError("`recording_amplitude_scalings['method']` not recognised.") + + moved_templates *= ordered_rec_scalings + + +def _check_generate_session_displacement_arguments( + num_units, recording_durations, recording_shifts, recording_amplitude_scalings +): + """ + Function to check the input arguments related to recording + shift and scale parameters are the correct size. + """ + expected_num_recs = len(recording_durations) + + if len(recording_shifts) != expected_num_recs: + raise ValueError( + "`recording_shifts` and `recording_durations` must be " + "the same length, the number of recordings to generate." + ) + + shifts_are_2d = [len(shift) == 2 for shift in recording_shifts] + if not all(shifts_are_2d): + raise ValueError("Each record entry for `recording_shifts` must have two elements, the x and y shift.") + + if recording_amplitude_scalings is not None: + + keys = recording_amplitude_scalings.keys() + if not "method" in keys or not "scalings" in keys: + raise ValueError("`recording_amplitude_scalings` must be a dict with keys `method` and `scalings`.") + + allowed_methods = ["by_passed_order", "by_amplitude_and_firing_rate", "by_firing_rate"] + if not recording_amplitude_scalings["method"] in allowed_methods: + raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}") + + rec_scalings = recording_amplitude_scalings["scalings"] + if not len(rec_scalings) == expected_num_recs: + raise ValueError("`recording_amplitude_scalings` 'scalings' must have one array per recording.") + + if not all([len(scale) == num_units for scale in rec_scalings]): + raise ValueError( + "The entry for each recording in `recording_amplitude_scalings` " + "must have the same length as the number of units." 
+ )
+
+
+def _update_kwargs_for_extended_units(
+ num_units,
+ channel_locations,
+ unit_locations,
+ generate_unit_locations_kwargs,
+ generate_templates_kwargs,
+ generate_sorting_kwargs,
+ fixed_generate_templates_kwargs,
+ fixed_generate_sorting_kwargs,
+ seed,
+):
+ """
+ In a real world situation, if the probe moves up / down
+ not only will previously recorded units be shifted, but
+ new units will be introduced into the recording.
+
+ This function extends the default number of units, unit locations,
+ and template / sorting kwargs so that the set of units spans
+ one probe's height (y dimension) above and below the probe.
+ Now, when the unit locations are shifted, new units will be
+ introduced into the recording (from below or above).
+
+ It is important that the unit kwargs for the units are kept the
+ same across runs when seeded (i.e. whether `shift_units_outside_probe`
+ is `True` or `False`). To achieve this, the fixed unit kwargs
+ are extended with new units located above and below these fixed
+ units. The seeds are shifted slightly, so the introduced
+ units do not duplicate the existing units.
+
+ """
+ seed_top = seed + 1 if seed is not None else None
+ seed_bottom = seed - 1 if seed is not None else None
+
+ # Set unit locations above and below the probe and extend
+ # the `unit_locations` array.
+ channel_locations_extend_top = channel_locations.copy()
+ channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1])
+
+ extend_top_locations = generate_unit_locations(
+ num_units,
+ channel_locations_extend_top,
+ seed=seed_top,
+ **generate_unit_locations_kwargs,
+ )
+
+ channel_locations_extend_bottom = channel_locations.copy()
+ channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1])
+
+ extend_bottom_locations = generate_unit_locations(
+ num_units,
+ channel_locations_extend_bottom,
+ seed=seed_bottom,
+ **generate_unit_locations_kwargs,
+ )
+
+ unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]
+
+ # For the new units located above and below the probe, generate a set of
+ # firing rates and template kwargs.
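# The channel-location offsets used above can be pictured on a toy probe
# (standalone sketch, arbitrary coordinates):
import numpy as np

toy_channel_locations = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])  # probe spanning 40 um in y

toy_extend_top = toy_channel_locations.copy()
toy_extend_top[:, 1] -= np.max(toy_channel_locations[:, 1])       # y now spans [-40, 0]

toy_extend_bottom = toy_channel_locations.copy()
toy_extend_bottom[:, 1] += np.max(toy_channel_locations[:, 1])    # y now spans [40, 80]

# Unit locations generated on these offset probes are stacked with the original
# ones (np.r_), tripling the number of units so that shifts can bring new units
# onto the probe.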
+ + # Extend the template kwargs + template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top) + template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom) + + for key in fixed_generate_templates_kwargs["unit_params"].keys(): + fixed_generate_templates_kwargs["unit_params"][key] = np.r_[ + template_kwargs_top["unit_params"][key], + fixed_generate_templates_kwargs["unit_params"][key], + template_kwargs_bottom["unit_params"][key], + ] + + # Extend the firing rates + firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top) + firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom) + + fixed_generate_sorting_kwargs["firing_rates"] = np.r_[ + firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom + ] + + # Update the number of units (3x as a + # new set above and below the existing units) + num_units *= 3 + + return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs diff --git a/src/spikeinterface/generation/tests/test_session_displacement_generator.py b/src/spikeinterface/generation/tests/test_session_displacement_generator.py new file mode 100644 index 0000000000..44f80acead --- /dev/null +++ b/src/spikeinterface/generation/tests/test_session_displacement_generator.py @@ -0,0 +1,485 @@ +import pytest + +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings +from spikeinterface.generation.drifting_generator import generate_drifting_recording +from spikeinterface.core import order_channels_by_depth +import numpy as np +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks + + +class TestSessionDisplacementGenerator: + """ + This class tests the `generate_session_displacement_recordings` that + returns a recordings / sorting in which the units are shifted + across sessions. This is achieved by shifting the unit locations + in both (x, y) on the generated templates that are used in + `InjectTemplatesRecording()`. + """ + + @pytest.fixture(scope="function") + def options(self): + """ + Set a set of base options that can be used in + `generate_session_displacement_recordings() ("kwargs") + and provide general information on the generated recordings. + These can be edited in the tests as required. + """ + options = { + "kwargs": { + "recording_durations": (10, 10, 25, 33), + "recording_shifts": ((0, 0), (2, -100), (-3, 275), (4, 1e6)), + "num_units": 5, + "extra_outputs": True, + "seed": 42, + }, + "num_recs": 4, + "y_bin_um": 10, + } + options["kwargs"]["generate_probe_kwargs"] = dict( + num_columns=1, + num_contact_per_column=128, + xpitch=16, + ypitch=options["y_bin_um"], + contact_shapes="square", + contact_shape_params={"width": 12}, + ) + + return options + + ### Tests + def test_x_y_rigid_shifts_are_properly_set(self, options): + """ + The session displacement works by generating a set of + templates shared across all recordings, but set with + different `unit_locations()`. Check here that the + (x, y) displacements passed in `recording_shifts` are properly + propagated. + + First, check the set `unit_locations` are moved as expected according + to the (x, y) shifts). Next, check the templates themselves are + moved as expected. 
The x-axis shift has the effect of changing + the template amplitude, and is not possible to test. However, + the y-axis shift shifts the maximum signal channel, so we check + the maximum signal channel o fthe templates is shifted as expected. + This implicitly tests the x-axis case as if the x-axis `unit_locations` + are shifted as expected, and the unit-locations are propagated + to the template, then (x, y) will both be working. + """ + output_recordings, _, extra_outputs = generate_session_displacement_recordings(**options["kwargs"]) + num_units = options["kwargs"]["num_units"] + recording_shifts = options["kwargs"]["recording_shifts"] + + # test unit locations are shifted as expected according + # to the record shifts + locations_1 = extra_outputs["unit_locations"][0] + + for rec_idx in range(1, 4): + + shifts = recording_shifts[rec_idx] + + assert np.array_equal( + locations_1 + np.r_[shifts, 0].astype(np.float32), extra_outputs["unit_locations"][rec_idx] + ) + + # Check that the generated templates are correctly shifted + # For each generated unit, check that the max loading channel is + # shifted as expected. In the case that the unit location is off the + # probe, check the maximum signal channel is the min / max channel on + # the probe, or zero (the unit is too far to reach the probe). + min_channel_loc = output_recordings[0].get_channel_locations()[0, 1] + max_channel_loc = output_recordings[0].get_channel_locations()[-1, 1] + for unit_idx in range(num_units): + + start_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][0][unit_idx], + options["y_bin_um"], + ) + + for rec_idx in range(1, options["num_recs"]): + + new_pos = self._get_peak_chan_loc_in_um( + extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"] + ) + + y_shift = recording_shifts[rec_idx][1] + if start_pos + y_shift > max_channel_loc: + assert new_pos == max_channel_loc or new_pos == 0 + elif start_pos + y_shift < min_channel_loc: + assert new_pos == min_channel_loc or new_pos == 0 + else: + assert np.isclose(new_pos, start_pos + y_shift, options["y_bin_um"]) + + # Confidence check the correct templates are + # loaded to the recording object. + for rec_idx in range(options["num_recs"]): + assert np.array_equal( + output_recordings[rec_idx].templates, + extra_outputs["templates_array_moved"][rec_idx], + ) + + def _get_peak_chan_loc_in_um(self, template_array, y_bin_um): + """ + Convenience function to get the maximally loading + channel y-position in um for the template. + """ + return np.argmax(np.max(template_array, axis=0)) * y_bin_um + + def test_recordings_length(self, options): + """ + Test that the `recording_durations` that sets the + length of each recording changes the recording + length as expected. + """ + output_recordings = generate_session_displacement_recordings(**options["kwargs"])[0] + + for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]): + assert rec.get_total_duration() == expected_rec_length + + def test_spike_times_and_firing_rates_across_recordings(self, options): + """ + Check the randomisation of spike times across recordings. + When a seed is set, this is passed to `generate_sorting` + and so the spike times across all records are expected + to be identical. However, if no seed is set, then the spike + times will be different across recordings. 
+ """ + options["kwargs"]["recording_durations"] = (10,) * options["num_recs"] + + output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3] + + options["kwargs"]["seed"] = None + output_sortings_different, extra_outputs_different = generate_session_displacement_recordings( + **options["kwargs"] + )[1:3] + + for unit_idx in range(options["kwargs"]["num_units"]): + for rec_idx in range(1, options["num_recs"]): + + # Exact spike times are not preserved when seed is None + assert np.array_equal( + output_sortings_same[0].get_unit_spike_train(unit_idx), + output_sortings_same[rec_idx].get_unit_spike_train(unit_idx), + ) + assert not np.array_equal( + output_sortings_different[0].get_unit_spike_train(unit_idx), + output_sortings_different[rec_idx].get_unit_spike_train(unit_idx), + ) + # Firing rates should always be preserved. + assert np.array_equal( + extra_outputs_same["firing_rates"][0][unit_idx], + extra_outputs_same["firing_rates"][rec_idx][unit_idx], + ) + assert np.array_equal( + extra_outputs_different["firing_rates"][0][unit_idx], + extra_outputs_different["firing_rates"][rec_idx][unit_idx], + ) + + @pytest.mark.parametrize("dim_idx", [0, 1]) + def test_x_y_shift_non_rigid(self, options, dim_idx): + """ + Check that the non-rigid shift changes the channel location + as expected. Non-rigid shifts are calculated depending on the + position of the channel. The `non_rigid_gradient` parameter + determines how much the position or 'distance' of the channel + (w.r.t the gradient of movement) affects the scaling. When + 0, the displacement is scaled by the distance. When 0, the + distance is ignored and all scalings are 1. + + This test checks the generated `unit_locations` under extreme + cases, when `non_rigid_gradient` is `None` or 0, which are equivalent, + and when it is `1`, and the displacement is directly propotional to + the unit position. + """ + options["kwargs"]["recording_shifts"] = ((0, 0), (10, 15), (15, 20), (20, 25)) + + _, _, rigid_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=None, + ) + _, _, nonrigid_max_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=0, + ) + _, _, nonrigid_none_info = generate_session_displacement_recordings( + **options["kwargs"], + non_rigid_gradient=1, + ) + + initial_locations = rigid_info["unit_locations"][0] + + # For each recording (i.e. each recording as different displacement + # w.r.t the first recording), check the rigid and nonrigid shifts + # are as expected. + for rec_idx in range(1, options["num_recs"]): + + shift = options["kwargs"]["recording_shifts"][rec_idx][dim_idx] + + # Get the rigid shift between the first recording and this shifted recording + # Check shifts for all unit locations are all the same. + shifts_rigid = self._get_shifts(rigid_info, rec_idx, dim_idx, initial_locations) + shifts_rigid = np.round(shifts_rigid, 5) + + assert np.unique(shifts_rigid).size == 1 + + # Get the nonrigid shift between the first recording and this recording. + # The shift for each unit should be directly proportional to its position. 
+ y_shifts_nonrigid = self._get_shifts(nonrigid_max_info, rec_idx, dim_idx, initial_locations) + + distance = np.linalg.norm(initial_locations, axis=1) + norm_distance = (distance - np.min(distance)) / (np.max(distance) - np.min(distance)) + + assert np.unique(y_shifts_nonrigid).size == options["kwargs"]["num_units"] + + # There is some small rounding error due to difference in distance computation, + # the main thing is the relative order not the absolute value. + assert np.allclose(y_shifts_nonrigid, shift * norm_distance, rtol=0, atol=0.5) + + # then do again with non-ridig-gradient 1 and check it matches rigid case + shifts_rigid_2 = self._get_shifts(nonrigid_none_info, rec_idx, dim_idx, initial_locations) + assert np.array_equal(shifts_rigid, np.round(shifts_rigid_2, 5)) + + def _get_shifts(self, extras_dict, rec_idx, dim_idx, initial_locations): + return extras_dict["unit_locations"][rec_idx][:, dim_idx] - initial_locations[:, dim_idx] + + def test_displacement_with_peak_detection(self, options): + """ + This test checks that the session displacement occurs + as expected under normal usage. Create a recording with a + single unit and a y-axis displacement. Find the peak + locations and check the shifted peak location is as expected, + within the tolerate of the y-axis pitch. + """ + # The seed is important here, otherwise the unit positions + # might go off the end of the probe. These kwargs are + # chosen to make the recording as small as possible as this + # test is slow for larger recordings. + y_shift = 50 + options["kwargs"]["recording_shifts"] = ((0, 0), (0, y_shift)) + options["kwargs"]["recording_durations"] = (0.5, 0.5) + options["num_recs"] = 2 + options["kwargs"]["num_units"] = 1 + options["kwargs"]["generate_probe_kwargs"]["num_contact_per_column"] = 18 + + output_recordings, _, _ = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + + first_recording = output_recordings[0] + + # Peak location of unshifted recording + peaks = detect_peaks(first_recording, method="by_channel") + peak_locs = localize_peaks(first_recording, peaks, method="center_of_mass") + first_pos = np.mean(peak_locs["y"]) + + # Find peak location on shifted recording and check it is + # the original location + shift. + shifted_recording = output_recordings[1] + peaks = detect_peaks(shifted_recording, method="by_channel") + peak_locs = localize_peaks(shifted_recording, peaks, method="center_of_mass") + + new_pos = np.mean(peak_locs["y"]) + + assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"]) + + def test_amplitude_scalings(self, options): + """ + Test that the templates are scaled by the passed scaling factors + in the specified order. The order can be in the passed order, + in the order of highest-to-lowest firing unit, or in the order + of (amplitude * firing_rate) (highest to lowest unit). 
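For example (illustrative values only), the "by_firing_rate" ordering mirrors
the argsort used in `_amplitude_scale_templates_in_place`:

    import numpy as np

    true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
    firing_rates = np.array([5.0, 4.0, 3.0, 2.0, 1.0])

    order_idx = np.argsort(firing_rates)   # ascending firing rate
    applied = true_scalings[order_idx]     # scaling each unit's template receives
    # applied -> [0.5, 0.4, 0.3, 0.2, 0.1]: with these toy values the
    # highest-firing unit receives the last scaling in the passed list.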
+ """ + # Setup arguments to create an unshifted set of recordings + # where the templates are to be scaled with `true_scalings` + options["kwargs"]["recording_durations"] = (10, 10) + options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0)) + options["kwargs"]["num_units"] == 5, + + true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + + recording_amplitude_scalings = { + "method": "by_passed_order", + "scalings": (np.ones(5), true_scalings), + } + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + ) + + # Check that the unit templates are scaled in the order + # the scalings were passed. + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings) + + # Now run, again applying the scalings in the order of + # unit firing rates (highest to lowest). + firing_rates = np.array([5, 4, 3, 2, 1]) + generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0) + recording_amplitude_scalings["method"] = "by_firing_rate" + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)]) + + # Finally, run again applying the scalings in the order of + # unit amplitude * firing_rate + recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order + amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1) + firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates) + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude]) + + def _calculate_scalings_from_output(self, extra_outputs): + first, second = extra_outputs["templates_array_moved"] + first_min = np.min(np.min(first, axis=2), axis=1) + second_min = np.min(np.min(second, axis=2), axis=1) + test_scalings = second_min / first_min + return test_scalings + + def test_metadata(self, options): + """ + Check that metadata required to be set of generated recordings is present + on all output recordings. 
+ """ + output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + num_chans = output_recordings[0].get_num_channels() + + for i in range(len(output_recordings)): + assert output_recordings[i].name == "InterSessionDisplacementRecording" + assert output_recordings[i]._annotations["is_filtered"] is True + assert output_recordings[i].has_probe() + assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans)) + assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans)) + + assert np.array_equal( + output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i] + ) + assert output_sortings[i].name == "InterSessionDisplacementSorting" + + def test_shift_units_outside_probe(self, options): + """ + When `shift_units_outside_probe` is `True`, a new set of + units above and below the probe (y dimension) are created, + such that they may be shifted into the recording. + + Here, check that these new units are created when `shift_units_outside_probe` + is on and that the kwargs for the central set of units match those + as when `shift_units_outside_probe` is `False`. + """ + num_sessions = len(options["kwargs"]["recording_durations"]) + _, _, baseline_outputs = generate_session_displacement_recordings( + **options["kwargs"], + ) + + _, _, outside_probe_outputs = generate_session_displacement_recordings( + **options["kwargs"], shift_units_outside_probe=True + ) + + num_units = options["kwargs"]["num_units"] + num_extended_units = num_units * 3 + + for ses_idx in range(num_sessions): + + # There are 3x the number of units when new units are created + # (one new set above, and one new set below the probe). + for key in ["unit_locations", "templates_array_moved", "firing_rates"]: + assert outside_probe_outputs[key][ses_idx].shape[0] == num_extended_units + + assert np.array_equal( + baseline_outputs[key][ses_idx], outside_probe_outputs[key][ses_idx][num_units:-num_units] + ) + + # The kwargs of the units in the central positions should be identical + # to those when `shift_units_outside_probe` is `False`. + lower_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][-num_units:][:, 1] + upper_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][:num_units][:, 1] + middle_unit_pos = baseline_outputs["unit_locations"][ses_idx][:, 1] + + assert np.min(upper_unit_pos) > np.max(middle_unit_pos) + assert np.max(lower_unit_pos) < np.min(middle_unit_pos) + + def test_same_as_generate_ground_truth_recording(self): + """ + It is expected that inter-session displacement randomly + generated recording and injected motion recording will + use exactly the same method to generate the ground-truth + recording (without displacement or motion). To check this, + set their kwargs equal and seed, then generate a non-displaced + recording. It should be identical to the static recroding + generated by `generate_drifting_recording()`. 
+ """ + + # Set some shared kwargs + num_units = 5 + duration = 10 + sampling_frequency = 30000.0 + probe_name = "Neuropixel-128" + generate_probe_kwargs = None + generate_unit_locations_kwargs = dict() + generate_templates_kwargs = dict(ms_before=1.5, ms_after=3) + generate_sorting_kwargs = dict(firing_rates=1) + generate_noise_kwargs = dict() + seed = 42 + + # Generate a inter-session displacement recording with no displacement + no_shift_recording, _ = generate_session_displacement_recordings( + num_units=num_units, + recording_durations=[duration], + recording_shifts=((0, 0),), + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + seed=seed, + ) + no_shift_recording = no_shift_recording[0] + + # Generate a drifting recording with no drift + static_recording, _, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + sampling_frequency=sampling_frequency, + probe_name=probe_name, + generate_probe_kwargs=generate_probe_kwargs, + generate_unit_locations_kwargs=generate_unit_locations_kwargs, + generate_templates_kwargs=generate_templates_kwargs, + generate_sorting_kwargs=generate_sorting_kwargs, + generate_noise_kwargs=generate_noise_kwargs, + generate_displacement_vector_kwargs=dict( + motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=None, + t_start_drift=1.0, + t_end_drift=None, + period_s=200, + ), + ] + ), + seed=seed, + ) + + # Check the templates and raw data match exactly. + assert np.array_equal( + no_shift_recording.get_traces(start_frame=0, end_frame=10), + static_recording.get_traces(start_frame=0, end_frame=10), + ) + + assert np.array_equal(no_shift_recording.templates, static_recording.drifting_templates.templates_array) diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index 3343217090..5b09a71be6 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -2,6 +2,17 @@ from .motion import correct_motion, load_motion_info, save_motion_info, get_motion_parameters_preset, get_motion_presets +""" +from .inter_session_alignment.session_alignment import ( + get_estimate_histogram_kwargs, + get_compute_alignment_kwargs, + get_non_rigid_window_kwargs, + get_interpolate_motion_kwargs, + align_sessions, + align_sessions_after_motion_correction, + compute_peaks_locations_for_session_alignment, +) +""" from .preprocessing_tools import get_spatial_interpolation_kernel from .detect_bad_channels import detect_bad_channels from .correct_lsb import correct_lsb diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/__init__.py b/src/spikeinterface/preprocessing/inter_session_alignment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py new file mode 100644 index 0000000000..a3ecac3db7 --- /dev/null +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -0,0 +1,636 @@ +from spikeinterface import BaseRecording +import numpy as np + +from spikeinterface.preprocessing import center +from spikeinterface.sortingcomponents.motion.motion_utils import 
make_2d_motion_histogram +from scipy.optimize import minimize +from scipy.ndimage import gaussian_filter +from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel + +# ############################################################################# +# Get Histograms +# ############################################################################# + + +def get_activity_histogram( + recording: BaseRecording, + peaks: np.ndarray, + peak_locations: np.ndarray, + spatial_bin_edges: np.ndarray, + log_scale: bool, + bin_s: float | None, + depth_smooth_um: float | None, + scale_to_hz: bool = False, + weight_with_amplitude: bool = False, +): + """ + Generate a 2D activity histogram for the session. Wraps the underlying + spikeinterface function with some adjustments for scaling to time and + log transform. + + Parameters + ---------- + + recording: BaseRecording, + A SpikeInterface recording object. + peaks: np.ndarray, + A SpikeInterface `peaks` array. + peak_locations: np.ndarray, + A SpikeInterface `peak_locations` array. + spatial_bin_edges: np.ndarray, + A (1 x n_bins + 1) array of spatial (probe y dimension) bin edges. + log_scale: bool, + If `True`, histogram is log scaled. + bin_s | None: float, + If `None`, a single histogram will be generated from all session + peaks. Otherwise, multiple histograms will be generated, one for + each time bin. + depth_smooth_um: float | None + If not `None`, smooth the histogram across the spatial + axis. see `make_2d_motion_histogram()` for details. + + TODO + ---- + - assumes 1-segment recording + - ask Sam whether it makes sense to integrate this function with `make_2d_motion_histogram`. + """ + activity_histogram, temporal_bin_edges, generated_spatial_bin_edges = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=weight_with_amplitude, + direction="y", + bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)), + bin_um=None, + hist_margin_um=None, + spatial_bin_edges=spatial_bin_edges, + depth_smooth_um=depth_smooth_um, + ) + assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing" + + temporal_bin_centers = get_bin_centers(temporal_bin_edges) + spatial_bin_centers = get_bin_centers(spatial_bin_edges) + + if scale_to_hz: + if bin_s is None: + scaler = 1 / recording.get_duration() + else: + scaler = 1 / np.diff(temporal_bin_edges)[:, np.newaxis] + + activity_histogram *= scaler + + if log_scale: + activity_histogram = np.log10(1 + activity_histogram) # TODO: make_2d_motion_histogram uses log2 + + return activity_histogram, temporal_bin_centers, spatial_bin_centers + + +def get_bin_centers(bin_edges): + return (bin_edges[1:] + bin_edges[:-1]) / 2 + + +def estimate_chunk_size(scaled_activity_histogram): + """ + Estimate a chunk size based on the firing rate. Intuitively, we + want longer chunk size to better estimate low firing rates. The + estimation computes a summary of the the firing rates for the session + by taking the value 25% of the max of the activity histogram. + + Then, the chunk size that will accurately estimate this firing rate + within 90% accuracy, 90% of the time based on assumption of Poisson + firing (based on CLT) is computed. + + Parameters + ---------- + + scaled_activity_histogram: np.ndarray + The activity histogram scaled to firing rate in Hz. + + TODO + ---- + - make the details available. 
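Worked example of the estimate described above (illustrative numbers, assuming
Poisson firing as in the function body):

    firing_rate = 10.0               # Hz, stands in for 25% of the histogram peak
    range_percent = 0.1              # accept a 10% error band
    confidence_z = 1.645             # ~90% of samples of a normal distribution
    e = firing_rate * range_percent  # allowed absolute error, 1 spike/s here
    t = firing_rate / (e / confidence_z) ** 2
    # t ~= 27.1 s: the chunk length needed to estimate a 10 Hz rate to within
    # 10%, 90% of the time, under the CLT approximation used below.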
+ """ + print("scaled max", np.max(scaled_activity_histogram)) + + firing_rate = np.max(scaled_activity_histogram) * 0.25 + + lambda_hat_s = firing_rate + range_percent = 0.1 + confidence_z = 1.645 # 90% of samples in the normal distribution + e = lambda_hat_s * range_percent + + t = lambda_hat_s / (e / confidence_z) ** 2 + + print( + f"Chunked histogram window size of: {t}s estimated " + f"for firing rate (25% of histogram peak) of {lambda_hat_s}" + ) + + return 10 + + +# ############################################################################# +# Chunked Histogram estimation methods +# ############################################################################# +# Given a set off chunked_session_histograms (num time chunks x num spatial bins) +# take the summary statistic over the time axis. + + +def get_chunked_hist_mean(chunked_session_histograms): + + mean_hist = np.mean(chunked_session_histograms, axis=0) + + std = np.std(chunked_session_histograms, axis=0, ddof=0) + + return mean_hist, std + + +def get_chunked_hist_median(chunked_session_histograms): + + median_hist = np.median(chunked_session_histograms, axis=0) + + quartile_1 = np.percentile(chunked_session_histograms, 25, axis=0) + quartile_3 = np.percentile(chunked_session_histograms, 75, axis=0) + + iqr = quartile_3 - quartile_1 + + return median_hist, iqr + + +def get_chunked_hist_supremum(chunked_session_histograms): + + max_hist = np.max(chunked_session_histograms, axis=0) + + min_hist = np.min(chunked_session_histograms, axis=0) + + scaled_range = (max_hist - min_hist) / (max_hist + 1e-12) + + return max_hist, scaled_range + + +def get_chunked_hist_poisson_estimate(chunked_session_histograms): + """ + Make a MLE estimate of the most likely value for each bin + given the assumption of Poisson firing. Turns out this is + basically identical to the mean :'D. + + Keeping for now as opportunity to add prior or do some outlier + removal per bin. But if not useful, deprecate in future. + """ + + def obj_fun(lambda_, m, sum_k): + return -(sum_k * np.log(lambda_) - m * lambda_) + + poisson_estimate = np.zeros(chunked_session_histograms.shape[1]) + std_devs = [] + for i in range(chunked_session_histograms.shape[1]): + + ks = chunked_session_histograms[:, i] + + std_devs.append(np.std(ks)) + m = ks.shape + sum_k = np.sum(ks) + + poisson_estimate[i] = minimize(obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),)).x + + raise NotImplementedError("This is the same as the mean, deprecate") + + return poisson_estimate + + +def get_chunked_hist_eigenvector(chunked_session_histograms): + """ + TODO: a little messy with the 2D stuff. Will probably deprecate anyway. 
+ """ + if chunked_session_histograms.shape[0] == 1: + return chunked_session_histograms.squeeze(), None + + is_2d = chunked_session_histograms.ndim == 3 + + if is_2d: + num_hist, num_spat_bin, num_amp_bin = chunked_session_histograms.shape + chunked_session_histograms = np.reshape(chunked_session_histograms, (num_hist, num_spat_bin * num_amp_bin)) + + A = chunked_session_histograms + S = (1 / A.shape[0]) * A.T @ A + + L, U = np.linalg.eigh(S) + + first_eigenvector = U[:, -1] * np.sqrt(L[-1]) + first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector is negative + + # Project all vectors (histograms) onto the principal component, + # then take the standard deviation in each dimension (over bins) + v1 = first_eigenvector[:, np.newaxis] + projection_onto_v1 = (A @ v1 @ v1.T) / (v1.T @ v1) + + v1_std = np.std(projection_onto_v1, axis=0) + + if is_2d: # TODO: double check this + first_eigenvector = np.reshape(first_eigenvector, (num_spat_bin, num_amp_bin)) + v1_std = np.reshape(v1_std, (num_spat_bin, num_amp_bin)) + + return first_eigenvector, v1_std + + +def get_chunked_gaussian_process_regression(chunked_session_histogram): + """ """ + # TODO: this is currently a placeholder implementation where the + # mean and variance over repeated samples is taken to run quickly. + # It would be better to use sparse version with repeated measures + # as done in pymc. + # TODO: try https://github.com/cornellius-gp/gpytorch + # even better : https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Heteroskedastic.html + # + + from sklearn.gaussian_process import GaussianProcessRegressor + from sklearn.gaussian_process.kernels import RBF, ConstantKernel + from sklearn.preprocessing import StandardScaler + import GPy + + chunked_session_histogram = chunked_session_histogram.copy() + chunked_session_histogram = chunked_session_histogram + + num_hist = chunked_session_histogram.shape[0] + num_bins = chunked_session_histogram.shape[1] + + X = np.arange(num_bins) + X_scaled = X + + Y = chunked_session_histogram + + bias_mean = False + if bias_mean: + # this is cool, bias the estimation towards the peak + Y = Y + np.mean(Y, axis=0) - np.percentile(Y, 5, axis=0) # TODO: avoid copy, also fix dims in case of square + + # normalise X and set lengthscale to 1 bin + mu_x = np.mean(X) + std_x = np.std(X) + X_scaled = (X - mu_x) / std_x + + lengthscale = 1 / std_x # 1 spatial bin + + mu_ystar = np.mean(Y) + std_ystar = np.std(Y) + + # take the mean and variance of Y replicates. 
Scale to the mean / standard deviation of all y + y_mean = np.mean(Y, axis=0) + y_var = np.std(Y, axis=0) + + Y_mean_scaled = (y_mean - mu_ystar) / std_ystar # standardise the normal way + Y_var_scaled = ( + 1 / std_ystar**2 + ) * y_var # this is a variance so need to scale to the square (TODO: see overleaf notes) + + kernel = GPy.kern.RBF(input_dim=1, lengthscale=lengthscale, variance=np.mean(Y_var_scaled)) # TODO: check this + + output_index2 = np.arange(num_bins) + Y_metadata2 = {"output_index": output_index2} + + likelihood = GPy.likelihoods.HeteroscedasticGaussian( + Y_metadata2, variance=Y_var_scaled + ) # one variance per y, but should be repeated for the same x + + gp = GPy.models.GPRegression(X_scaled.reshape(-1, 1), Y_mean_scaled.reshape(-1, 1), kernel, Y_metadata2) + + gp.likelihood = likelihood + + gp.optimize(messages=True) + + mean_pred, var_pred = gp.predict(X_scaled.reshape(-1, 1), Y_metadata=Y_metadata2) + + mean_pred = (mean_pred * std_ystar) + mu_ystar + var_pred = var_pred * std_ystar**2 + + std_pred = np.sqrt(var_pred) + + return mean_pred, std_pred, gp + + +# ############################################################################# +# TODO: MOVE creating recordings +# ############################################################################# + + +def compute_histogram_crosscorrelation( + session_histogram_list: list[np.ndarray], + non_rigid_windows: np.ndarray, + num_shifts_block: int, + interpolate: bool, + interp_factor: int, + kriging_sigma: float, + kriging_p: float, + kriging_d: float, + smoothing_sigma_bin: float, + smoothing_sigma_window: float, +): + """ + Given a list of session activity histograms, cross-correlate + all histograms returning the peak correlation shift (in indices) + in a symmetric (num_session x num_session) matrix. + + Supports non-rigid estimation by windowing the activity histogram + and performing separate cross-correlations on each window separately. + + Parameters + ---------- + + session_histogram_list : list[np.ndarray] + non_rigid_windows : np.ndarray + A (num windows x num_bins) binary of weights by which to window + the activity histogram for non-rigid-registration. For example, if + 2 rectangular masks were used, there would be a two row binary mask + the first row with mask of the first half of the probe and the second + row a mask for the second half of the probe. + num_shifts_block : int + Number of indices by which to shift the histogram to find the maximum + of the cross correlation. If `None`, the entire activity histograms + are cross-correlated. + interpolate : bool + If `True`, the cross-correlation is interpolated before maximum is taken. + interp_factor: + Factor by which to interpolate the cross-correlation. + kriging_sigma : float + sigma parameter for kriging_kernel function. See `kriging_kernel`. + kriging_p : float + p parameter for kriging_kernel function. See `kriging_kernel`. + kriging_d : float + d parameter for kriging_kernel function. See `kriging_kernel`. + smoothing_sigma_bin : float + sigma parameter for the gaussian smoothing kernel over the + spatial bins. + smoothing_sigma_window : float + sigma parameter for the gaussian smoothing kernel over the + non-rigid windows. + + Returns + ------- + + shift_matrix : ndarray + A (num_session x num_session) symmetric matrix of shifts + (indices) between pairs of session activity histograms. + + Notes + ----- + + - This function is very similar to the IterativeTemplateRegistration + function used in motion correct, though slightly difference in scope. 
+ It was not convenient to merge them at this time, but worth looking + into in future. + + - Some obvious performances boosts, not done so because already fast + 1) the cross correlations for each session comparison are performed + twice. They are slightly different due to interpolation, but + still probably better to calculate once and flip. + 2) `num_shifts_block` is implemented by simply making the full + cross correlation. Would probably be nicer to explicitly calculate + only where needed. However, in general these cross correlations are + only a few thousand datapoints and so are already extremely + fast to cross correlate. + + Notes + ----- + + - The original kilosort method does not work in the inter-session + context because it averages over time bins to form a template to + align too. In this case, averaging over a small number of possibly + quite different session histograms does not work well. + + - In the nonrigid case, this strategy can completely fail when the xcorr + is very bad for a certain window. The smoothing and interpolation + make it much worse, because bad xcorr are merged together. The x-corr + can be bad when the recording is shifted a lot and so there are empty + regions that are correlated with non-empty regions in the nonrigid + approach. A different approach will need to be taken in this case. + + Note that kilosort method does not work because creating a + mean does not make sense over sessions. + """ + num_sessions = len(session_histogram_list) + num_bins = session_histogram_list.shape[1] # all hists are same length + num_windows = non_rigid_windows.shape[0] + + shift_matrix = np.zeros((num_sessions, num_sessions, num_windows)) + + center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int) + + for i in range(num_sessions): + for j in range(num_sessions): + + # Create the (num windows, num_bins) matrix for this pair of sessions + + import matplotlib.pyplot as plt + + # TODO: plot everything + + num_iter = ( + num_bins * 2 - 1 if not num_shifts_block else num_shifts_block * 2 + ) # TODO: make sure this is clearly defined, it is either side... + xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter)) + + # For each window, window the session histograms (`window` is binary) + # and perform the cross correlations + for win_idx, window in enumerate(non_rigid_windows): + + # breakpoint() + # TODO: track the 2d histogram through all steps to check everything is working okay + + # TOOD: gaussian window with crosscorr, won't it strongly bias zero shifts by maximising the signal at 0? + # TODO: add weight option. + # TODO: damn this is slow for 2D, speed up. 
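# Standalone sketch (toy 1D histograms, not part of this loop) of how a rigid
# shift is read off the full cross-correlation computed below:
import numpy as np

toy_hist_i = np.array([0.0, 1.0, 3.0, 1.0, 0.0, 0.0])
toy_hist_j = np.array([0.0, 0.0, 1.0, 3.0, 1.0, 0.0])  # toy_hist_i shifted right by one bin

toy_xcorr = np.correlate(toy_hist_i, toy_hist_j, mode="full")         # length 2 * num_bins - 1
toy_center_bin = np.floor((toy_hist_i.size * 2 - 1) / 2).astype(int)
toy_shift = np.argmax(toy_xcorr) - toy_center_bin
# toy_shift == -1: the peak of toy_hist_i (bin 2) sits one bin below that of toy_hist_j (bin 3).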
+ if session_histogram_list.ndim == 3: + windowed_histogram_i = session_histogram_list[i, :] * window[:, np.newaxis] + windowed_histogram_j = session_histogram_list[j, :] * window[:, np.newaxis] + + from scipy.signal import correlate2d + + # carefully check indices + xcorr = correlate2d( + windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis], + windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis], + ) # TOOD: check speed, probs don't remove mean because we want zeros for unmasked version + + mid_idx = windowed_histogram_j.shape[1] - 1 + xcorr = xcorr[:, mid_idx] + + else: + windowed_histogram_i = session_histogram_list[i, :] * window + + window_target = True # this makes less sense now that things could be very far apart + if window_target: + windowed_histogram_j = session_histogram_list[j, :] * window + else: + windowed_histogram_j = session_histogram_list[j, :] + + xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full") + + # plt.plot(windowed_histogram_i) + # plt.plot(windowed_histogram_j) + # plt.show() + + if num_shifts_block: + window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block) + xcorr = xcorr[window_indices] + shift_center_bin = ( + num_shifts_block # np.floor(num_shifts_block / 2) # TODO: CHECK! and move out of loop! + ) + else: + shift_center_bin = center_bin + + # plt.plot(xcorr) + # plt.show() + + xcorr_matrix[win_idx, :] = xcorr + + # TODO: check absolute value of different bins, they are quite different (log scale, zero mean histograms) + # TODO: print out a load of quality metrics from this! + + # Smooth the cross-correlations across the bins + if smoothing_sigma_bin: + xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1) + + # Smooth the cross-correlations across the windows + if num_windows > 1 and smoothing_sigma_window: + xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_window, axes=0) + + # Upsample the cross-correlation + if interpolate: + shifts = np.arange(xcorr_matrix.shape[1]) + shifts_upsampled = np.linspace(shifts[0], shifts[-1], shifts.size * interp_factor) + + K = kriging_kernel( + np.c_[np.ones_like(shifts), shifts], + np.c_[np.ones_like(shifts_upsampled), shifts_upsampled], + kriging_sigma, + kriging_p, + kriging_d, + ) + xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)]) + + xcorr_peak = np.argmax(xcorr_matrix, axis=1) / interp_factor + else: + xcorr_peak = np.argmax(xcorr_matrix, axis=1) + + # breakpoint() + + shift = xcorr_peak - shift_center_bin # center_bin + shift_matrix[i, j, :] = shift + + return shift_matrix + + +def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: + """ + Shift an array by `shift` indices, padding with zero. + Samples going out of bounds are dropped i,e, the array is not + extended and samples are not wrapped around to the start of the array. + + Parameters + ---------- + + array : np.ndarray + The array to pad. + shift : int + Number of indices why which to shift the array. If positive, the + zeros are added from the end of the array. If negative, the zeros + are added from the start of the array. + + Returns + ------- + + cut_padded_array : np.ndarray + The `array` padded with zeros and cut down (i.e. out of bounds + samples dropped). 
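For example, assuming `numpy` is imported as `np` (as at the top of this module):

    shift_array_fill_zeros(np.array([1, 2, 3, 4]), shift=2)   # -> array([3, 4, 0, 0])
    shift_array_fill_zeros(np.array([1, 2, 3, 4]), shift=-2)  # -> array([0, 0, 1, 2])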
+ + """ + abs_shift = np.abs(shift) + pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) + + if array.ndim == 2: + pad_tuple = (pad_tuple, (0, 0)) + + padded_hist = np.pad(array, pad_tuple, mode="constant") + + if padded_hist.ndim == 2: + cut_padded_array = padded_hist[abs_shift:, :] if shift >= 0 else padded_hist[:-abs_shift, :] # TOOD: tidy up + else: + cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] + + return cut_padded_array + + +def akima_interpolate_nonrigid_shifts( + non_rigid_shifts: np.ndarray, + non_rigid_window_centers: np.ndarray, + spatial_bin_centers: np.ndarray, +): + """ + Perform Akima spline interpolation on a set of non-rigid shifts. + The non-rigid shifts are per segment of the probe, each segment + containing a number of channels. Interpolating these non-rigid + shifts to the spatial bin centers gives a more accurate shift + per channel. + + Parameters + ---------- + non_rigid_shifts : np.ndarray + non_rigid_window_centers : np.ndarray + spatial_bin_centers : np.ndarray + + Returns + ------- + interp_nonrigid_shifts : np.ndarray + An array (length num_spatial_bins) of shifts + interpolated from the non-rigid shifts. + + TODO + ---- + requires scipy 14 + """ + from scipy.interpolate import Akima1DInterpolator + + x = non_rigid_window_centers + xs = spatial_bin_centers + + num_sessions = non_rigid_shifts.shape[0] + num_bins = spatial_bin_centers.shape[0] + + interp_nonrigid_shifts = np.zeros((num_sessions, num_bins)) + for ses_idx in range(num_sessions): + + y = non_rigid_shifts[ses_idx] + y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) + interp_nonrigid_shifts[ses_idx, :] = y_new + + return interp_nonrigid_shifts + + +def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray): + """ + Given a matrix of displacements between all sessions, find the + shifts (one per session) to bring the sessions into alignment. + + Parameters + ---------- + alignment_order : "to_middle" or "to_session_X" where + "N" is the number of the session to align to. + session_offsets_matrix : np.ndarray + The num_sessions x num_sessions symmetric matrix + of displacements between all sessions, generated by + `_compute_session_alignment()`. + + Returns + ------- + optimal_shift_indices : np.ndarray + A 1 x num_sessions array of shifts to apply to + each session in order to bring all sessions into + alignment. + """ + if alignment_order == "to_middle": + optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0) + else: + ses_idx = int(alignment_order.split("_")[-1]) - 1 + optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :] + + return optimal_shift_indices diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py new file mode 100644 index 0000000000..0f29d1f508 --- /dev/null +++ b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py @@ -0,0 +1,328 @@ +import itertools + +from spikeinterface.core import BaseRecording +import numpy as np +from spikeinterface.widgets.base import BaseWidget +from spikeinterface.widgets.base import to_attr +from spikeinterface.widgets.motion import DriftRasterMapWidget +from matplotlib.animation import FuncAnimation + +# TODO: decide on name, Displacement vs. 
Alignment + + +# Animation +# TODO: temp functions +def _plot_2d_histogram_as_animation(chunked_histogram): + fig, ax = plt.subplots() + im = ax.imshow(chunked_histograms[0, :, :], origin="lower", cmap="Blues", aspect="auto") + + def update(frame): + im.set_data(chunked_histograms[frame, :, :]) + ax.set_title(f"Slice {frame}") + return [im] + + FuncAnimation(fig, update, frames=chunked_histograms.shape[0], interval=100) + plt.show() + + +def _plot_session_histogram_and_variation(session_histogram, variation): + plt.imshow(session_histogram, origin="lower", cmap="Blues", aspect="auto") + plt.title("Summary Histogram") + plt.xlabel("Amplitude bin") + plt.ylabel("Depth (um)") + plt.show() + + plt.imshow(variation, origin="lower", cmap="Blues", aspect="auto") + plt.title("Variation") + plt.xlabel("Amplitude bin") + plt.ylabel("Depth (um)") + plt.show() + + +class SessionAlignmentWidget(BaseWidget): + def __init__( + self, + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + session_histogram_list: list[np.ndarray], + spatial_bin_centers: np.ndarray | None = None, + corrected_peak_locations_list: list[np.ndarray] | None = None, + corrected_session_histogram_list: list[np.ndarray] = None, + drift_raster_map_kwargs: dict | None = None, + session_alignment_histogram_kwargs: dict | None = None, + **backend_kwargs, + ): + """ + Widget to display the output of inter-session alignment. + In the top section, `DriftRasterMapWidget`s are used to display + the raster maps for each session, before and after alignment. + The order of all lists should correspond to the same recording. + + If histograms are provided, `SessionAlignmentHistogramWidget` + are used to show the activity histograms, before and after alignment. + See `align_sessions` for context. + + Corrected and uncorrected activity histograms are generated + as part of the `align_sessions` step. + + Parameters + ---------- + + recordings_list : list[BaseRecording] + List of recordings to plot. + peaks_list : list[np.ndarray] + List of detected peaks for each session. + peak_locations_list : list[np.ndarray] + List of detected peak locations for each session. + session_histogram_list : np.ndarray | None + A list of activity histograms as output from `align_sessions`. + If `None`, no histograms will be displayed. + spatial_bin_centers=None : np.ndarray | None + Spatial bin centers for the histogram (each session activity + histogram will have the same spatial bin centers). + corrected_peak_locations_list : list[np.ndarray] | None + A list of corrected peak locations. If provided, the corrected + raster plots will be displayed. + corrected_session_histogram_list : list[np.ndarray] + A list of corrected session activity histograms, as + output from `align_sessions`. + drift_raster_map_kwargs : dict | None + Kwargs to be passed to `DriftRasterMapWidget`. + session_alignment_histogram_kwargs : dict | None + Kwargs to be passed to `SessionAlignmentHistogramWidget`. + **backend_kwargs + """ + + # TODO: check all lengths more carefully e.g. histogram vs. peaks. + + assert len(recordings_list) <= 8, ( + "At present, this widget supports plotting up to 8 sessions. " + "Please contact SpikeInterface to discuss increasing." + ) + if corrected_session_histogram_list is not None: + if not len(corrected_session_histogram_list) == len(session_histogram_list): + raise ValueError( + "`corrected_session_histogram_list` must be the same length as `session_histogram_list`. 
" + "Entries should correspond exactly, with the histogram in each position being the corrected" + "version of `session_histogram_list`." + ) + if corrected_peak_locations_list is not None: + if not len(corrected_peak_locations_list) == len(peak_locations_list): + raise ValueError( + "`corrected_peak_locations_list` must be the same length as `peak_locations_list`. " + "Entries should correspond exactly, with the histogram in each position being the corrected" + "version of `peak_locations_list`." + ) + if (corrected_peak_locations_list is None) != (corrected_session_histogram_list is None): + raise ValueError( + "If either `corrected_peak_locations_list` or `corrected_session_histogram_list` " + "is passed, they must both be passed." + ) + + if drift_raster_map_kwargs is None: + drift_raster_map_kwargs = {} + + if session_alignment_histogram_kwargs is None: + session_alignment_histogram_kwargs = {} + + plot_data = dict( + recordings_list=recordings_list, + peaks_list=peaks_list, + peak_locations_list=peak_locations_list, + session_histogram_list=session_histogram_list, + spatial_bin_centers=spatial_bin_centers, + corrected_peak_locations_list=corrected_peak_locations_list, + corrected_session_histogram_list=corrected_session_histogram_list, + drift_raster_map_kwargs=drift_raster_map_kwargs, + session_alignment_histogram_kwargs=session_alignment_histogram_kwargs, + ) + + BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + """ + Create the `SessionAlignmentWidget` for matplotlib. + """ + from spikeinterface.widgets.utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + # TODO: direct copy + assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["ax"] is None, "ax argument is not allowed in MotionWidget" + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + # TODO: use self.axes I think. 
+ min_y = np.min(np.hstack([locs["y"] for locs in dp.peak_locations_list])) + max_y = np.max(np.hstack([locs["y"] for locs in dp.peak_locations_list])) + + if dp.corrected_peak_locations_list is None: + # TODO: Own function + num_cols = np.min([4, len(dp.peak_locations_list)]) + num_rows = 1 if num_cols <= 4 else 2 + + ordered_row_col = list(itertools.product(range(num_rows), range(num_cols))) + + gs = fig.add_gridspec(num_rows + 1, num_cols, wspace=0.3, hspace=0.5) + + for i, row_col in enumerate(ordered_row_col): + + ax = fig.add_subplot(gs[row_col]) + + DriftRasterMapWidget( + dp.peaks_list[i], + dp.peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax, + **dp.drift_raster_map_kwargs, + ) + ax.set_ylim((min_y, max_y)) + else: + + # Own function, then see if can compare + num_cols = len(dp.peak_locations_list) + num_rows = 2 + + gs = fig.add_gridspec(num_rows + 1, num_cols, wspace=0.3, hspace=0.5) + + for i in range(num_cols): + + ax_top = fig.add_subplot(gs[0, i]) + ax_bottom = fig.add_subplot(gs[1, i]) + + DriftRasterMapWidget( + dp.peaks_list[i], + dp.peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax_top, + **dp.drift_raster_map_kwargs, + ) + ax_top.set_title(f"Session {i + 1}") + ax_top.set_xlabel(None) + ax_top.set_ylim((min_y, max_y)) + + DriftRasterMapWidget( + dp.peaks_list[i], + dp.corrected_peak_locations_list[i], + recording=dp.recordings_list[i], + ax=ax_bottom, + **dp.drift_raster_map_kwargs, + ) + ax_bottom.set_title(f"Corrected Session {i + 1}") + ax_bottom.set_ylim((min_y, max_y)) + + # TODO: then histograms. + num_sessions = len(dp.session_histogram_list) + + if "legend" not in dp.session_alignment_histogram_kwargs: + sessions = [f"session {i + 1}" for i in range(num_sessions)] + dp.session_alignment_histogram_kwargs["legend"] = sessions + + if not dp.corrected_session_histogram_list: + + ax = fig.add_subplot(gs[num_rows, :]) + + SessionAlignmentHistogramWidget( + dp.session_histogram_list, + dp.spatial_bin_centers, + ax=ax, + **dp.session_alignment_histogram_kwargs, + ) + ax.legend(loc="upper left") + else: + + gs_sub = gs[num_rows, :].subgridspec(1, 2) + + ax_left = fig.add_subplot(gs_sub[0]) + ax_right = fig.add_subplot(gs_sub[1]) + + SessionAlignmentHistogramWidget( + dp.session_histogram_list, + dp.spatial_bin_centers, + ax=ax_left, + **dp.session_alignment_histogram_kwargs, + ) + SessionAlignmentHistogramWidget( + dp.corrected_session_histogram_list, + dp.spatial_bin_centers, + ax=ax_right, + **dp.session_alignment_histogram_kwargs, + ) + ax_left.get_legend().set_loc("upper right") + ax_left.set_title("Original Histogram") + ax_right.get_legend().set_loc("upper right") + ax_right.set_title("Corrected Histogram") + + +class SessionAlignmentHistogramWidget(BaseWidget): + """ """ + + def __init__( + self, + session_histogram_list: list[np.ndarray], + spatial_bin_centers: list[np.ndarray] | np.ndarray | None, + legend: None | list[str] = None, + linewidths: None | list[float] = 2, + colors: None | list = None, + **backend_kwargs, + ): + + plot_data = dict( + session_histogram_list=session_histogram_list, + spatial_bin_centers=spatial_bin_centers, + legend=legend, + linewidths=linewidths, + colors=colors, + ) + + BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + from spikeinterface.widgets.utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + legend = dp.legend + colors = dp.colors + linewidths = dp.linewidths + spatial_bin_centers = 
dp.spatial_bin_centers + + assert backend_kwargs["axes"] is None, "use `ax` to pass an axis to set." + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + num_histograms = len(dp.session_histogram_list) + + if isinstance(colors, int) or colors is None: + colors = [colors] * num_histograms + + if isinstance(linewidths, int): + linewidths = [linewidths] * num_histograms + + # TODO: this leads to quite unexpected behaviours, figure something else out. + if spatial_bin_centers is None: + num_bins = dp.session_histogram_list[0].size + spatial_bin_centers = [np.arange(num_bins)] * num_histograms + + elif isinstance(spatial_bin_centers, np.ndarray): + spatial_bin_centers = [spatial_bin_centers] * num_histograms + + if dp.session_histogram_list[0].ndim == 2: + histogram_list = [np.sum(hist_, axis=1) for hist_ in dp.session_histogram_list] + print("2D histogram passed, will be summed across first (i.e. amplitude) axis.") + else: + histogram_list = dp.session_histogram_list + + for i in range(num_histograms): + self.ax.plot(spatial_bin_centers[i], histogram_list[i], color=colors[i], linewidth=linewidths[i]) + + if legend is not None: + self.ax.legend(legend) + + self.ax.set_xlabel("Spatial bins (um)") + self.ax.set_ylabel("Firing rate (Hz)") # TODO: this is an assumption based on the + # output of histogram estimation diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py new file mode 100644 index 0000000000..949a740bf0 --- /dev/null +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -0,0 +1,1286 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from spikeinterface.core.baserecording import BaseRecording + +import numpy as np +from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording +from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, Motion, get_spatial_bins +from spikeinterface.sortingcomponents.motion.motion_interpolation import correct_motion_on_peaks + +from spikeinterface.preprocessing.inter_session_alignment import alignment_utils +from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node +import copy +import scipy + + +def get_estimate_histogram_kwargs() -> dict: + """ + A dictionary controlling how the histogram for each session is + computed. The session histograms are estimated by chunking + the recording into time segments and computing histograms + for each chunk, then performing some summary statistic over + the chunked histograms. + + Returns + ------- + A dictionary with entries: + + "bin_um" : number of spatial histogram bins. As the estimated peak + locations are continuous (i.e. real numbers) this is not constrained + by the number of channels. + "method" : may be "chunked_mean", "chunked_median", "chunked_supremum", + "chunked_poisson". Determines the summary statistic used over + the histograms computed across a session. See `alignment_utils.py + for details on each method. + "chunked_bin_size_s" : The length in seconds (float) to chunk the recording + for estimating the chunked histograms. Can be set to "estimate" (str), + and the size is estimated from firing frequencies. + "log_scale" : if `True`, histograms are log transformed. + "depth_smooth_um" : if `None`, no smoothing is applied. See + `make_2d_motion_histogram`. 
+ """ + return { + "bin_um": 2, + "method": "chunked_mean", + "chunked_bin_size_s": "estimate", + "log_scale": False, + "depth_smooth_um": None, + "histogram_type": "activity_1d", + "weight_with_amplitude": True, + } + + +def get_compute_alignment_kwargs() -> dict: + """ + A dictionary with settings controlling how inter-session + alignment is estimated and computed given a set of + session activity histograms. + + All keys except for "non_rigid_window_kwargs" determine + how alignment is estimated, based on the kilosort ("kilosort_like" + in spikeinterface) motion correction method. See + `iterative_template_registration` for details. + + "non_rigid_window_kwargs" : if nonrigid alignment + is performed, this determines the nature of the + windows along the probe depth. See `get_spatial_windows`. + """ + return { + "num_shifts_block": 50, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction + "interpolate": False, + "interp_factor": 10, + "kriging_sigma": 1, + "kriging_p": 2, + "kriging_d": 2, + "smoothing_sigma_bin": 0.5, + "smoothing_sigma_window": 0.5, + "akima_interp_nonrigid": False, + } + + +def get_non_rigid_window_kwargs(): + """ + see get_spatial_windows() for parameters. + + TODO + ---- + merge with motion correction kwargs which are + defined in the function signature. + """ + return { + "rigid_mode": "rigid", # "rigid", "rigid_nonrigid", "nonrigid" + "win_shape": "gaussian", + "win_step_um": 50, + "win_scale_um": 50, + "win_margin_um": None, + "zero_threshold": None, + } + + +def get_interpolate_motion_kwargs(): + """ + Settings to pass to `InterpolateMotionRecording`, + see that class for parameter descriptions. + """ + return {"border_mode": "remove_channels", "spatial_interpolation_method": "kriging", "sigma_um": 20.0, "p": 2} + + +############################################################################### +# Public Entry Level Functions +############################################################################### + +# TODO: sometimes with small bins, the interpolation spreads the signal over too small a bin and flattens it on the corrected histogram + + +def align_sessions( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + alignment_order: str = "to_middle", + non_rigid_window_kwargs: dict = get_non_rigid_window_kwargs(), + estimate_histogram_kwargs: dict = get_estimate_histogram_kwargs(), + compute_alignment_kwargs: dict = get_compute_alignment_kwargs(), + interpolate_motion_kwargs: dict = get_interpolate_motion_kwargs(), +) -> tuple[list[BaseRecording], dict]: + """ + Estimate probe displacement across recording sessions and + return interpolated, displacement-corrected recording. Displacement + is only estimated along the "y" dimension. + + This assumes peaks and peak locations have already been computed. + See `compute_peaks_locations_for_session_alignment` for generating + `peaks_list` and `peak_locations_list` from a `recordings_list`. + + If a recording in `recordings_list` is already an `InterpolateMotionRecording`, + the displacement will be added to the existing shifts to avoid duplicate + interpolations. Note the returned, corrected recording is a copy + (recordings in `recording_list` are not edited in-place). + + Parameters + ---------- + recordings_list : list[BaseRecording] + A list of recordings to be aligned. 
+ peaks_list : list[np.ndarray] + A list of peaks detected from the recordings in `recordings_list`, + as returned from the `detect_peaks` function. Each entry in + `peaks_list` should be from the corresponding entry in `recordings_list`. + peak_locations_list : list[np.ndarray] + A list of peak locations, as computed by `localize_peaks`. Each entry + in `peak_locations_list` should be matched to the corresponding entry + in `peaks_list` and `recordings_list`. + alignment_order : str + "to_middle" will align all sessions to the mean position. + Alternatively, "to_session_N" where "N" is a session number + will align to the Nth session. + non_rigid_window_kwargs : dict + see `get_non_rigid_window_kwargs` + estimate_histogram_kwargs : dict + see `get_estimate_histogram_kwargs()` + compute_alignment_kwargs : dict + see `get_compute_alignment_kwargs()` + interpolate_motion_kwargs : dict + see `get_interpolate_motion_kwargs()` + + Returns + ------- + `corrected_recordings_list : list[BaseRecording] + List of displacement-corrected recordings (corresponding + in order to `recordings_list`). If an input recordings is + an InterpolateMotionRecording` recording, the corrected + output recording will be a copy of the input recording with + the additional displacement correction added. + + extra_outputs_dict : dict + Dictionary of features used in the alignment estimation and correction. + + shifts_array : np.ndarray + A (num_sessions x num_rigid_windows) array of shifts. + session_histogram_list : list[np.ndarray] + A list of histograms (one per session) used for the alignment. + spatial_bin_centers : np.ndarray + The spatial bin centers, shared between all recordings. + temporal_bin_centers_list : list[np.ndarray] + List of temporal bin centers. As alignment is based on a single + histogram per session, this contains only 1 value per recording, + which is the mid-timepoint of the recording. + non_rigid_window_centers : np.ndarray + Window centers of the probe segments used for non-rigid alignment. + If rigid alignment is performed, this is a single value (mid-probe). + non_rigid_windows : np.ndarray + A (num nonrigid windows, num spatial_bin_centers) binary array used to mask + the probe segments for non-rigid alignment. If rigid alignment is performed, + this a vector of ones with length (spatial_bin_centers,) + histogram_info_list :list[dict] + see `_get_single_session_activity_histogram()` for details. + motion_objects_list : + List of motion objects containing the shifts and spatial and temporal + bins for each recording. Note this contains only displacement + associated with the inter-session alignment, and so will differ from + the motion on corrected recording objects if the recording is + already an `InterpolateMotionRecording` object containing + within-session motion correction. + corrected : dict + Dictionary containing corrected-histogram + information. + corrected_peak_locations_list : + Displacement-corrected `peak_locations`. + corrected_session_histogram_list : + Corrected activity histogram (computed from the corrected peak locations). + """ + non_rigid_window_kwargs = copy.deepcopy(non_rigid_window_kwargs) + estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs) + compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) + interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs) + + # Ensure list lengths match and all channel locations are the same across recordings. 
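    # A hedged end-to-end sketch of the intended workflow (the import path is taken from
    # this module's location in the diff; loading and preprocessing of the recordings are
    # placeholders, not part of this API):
    #
    #   from spikeinterface.preprocessing.inter_session_alignment.session_alignment import (
    #       align_sessions,
    #       compute_peaks_locations_for_session_alignment,
    #   )
    #
    #   peaks_list, peak_locations_list = compute_peaks_locations_for_session_alignment(
    #       recordings_list, detect_kwargs={}, localize_peaks_kwargs={}
    #   )
    #   corrected_recordings_list, extra_outputs = align_sessions(
    #       recordings_list, peaks_list, peak_locations_list, alignment_order="to_middle"
    #   )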
+ _check_align_sessions_inputs( + recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs + ) + + print("Computing a single activity histogram from each session...") + + (session_histogram_list, temporal_bin_centers_list, spatial_bin_centers, spatial_bin_edges, histogram_info_list) = ( + _compute_session_histograms(recordings_list, peaks_list, peak_locations_list, **estimate_histogram_kwargs) + ) + + print("Aligning the activity histograms across sessions...") + + contact_depths = recordings_list[0].get_channel_locations()[:, 1] + + shifts_array, non_rigid_windows, non_rigid_window_centers = _compute_session_alignment( + session_histogram_list, + contact_depths, + spatial_bin_centers, + alignment_order, + non_rigid_window_kwargs, + compute_alignment_kwargs, + ) + shifts_array *= estimate_histogram_kwargs["bin_um"] + + print("Creating corrected recordings...") + + corrected_recordings_list, motion_objects_list = _create_motion_recordings( + recordings_list, shifts_array, temporal_bin_centers_list, non_rigid_window_centers, interpolate_motion_kwargs + ) + + print("Creating corrected peak locations and histograms...") + + corrected_peak_locations_list, corrected_session_histogram_list = _correct_session_displacement( + corrected_recordings_list, + peaks_list, + peak_locations_list, + motion_objects_list, + spatial_bin_edges, + estimate_histogram_kwargs, + ) + + extra_outputs_dict = { + "shifts_array": shifts_array, + "session_histogram_list": session_histogram_list, + "spatial_bin_centers": spatial_bin_centers, + "temporal_bin_centers_list": temporal_bin_centers_list, + "non_rigid_window_centers": non_rigid_window_centers, + "non_rigid_windows": non_rigid_windows, + "histogram_info_list": histogram_info_list, + "motion_objects_list": motion_objects_list, + "corrected": { + "corrected_peak_locations_list": corrected_peak_locations_list, + "corrected_session_histogram_list": corrected_session_histogram_list, + }, + } + return corrected_recordings_list, extra_outputs_dict + + +def align_sessions_after_motion_correction( + recordings_list: list[BaseRecording], motion_info_list: list[dict], align_sessions_kwargs: dict | None +) -> tuple[list[BaseRecording], dict]: + """ + Convenience function to run `align_sessions` to correct for + inter-session displacement from the outputs of motion correction. + + The estimated displacement will be added directly to the recording. + + Parameters + ---------- + recordings_list : list[BaseRecording] + A list of motion-corrected (`InterpolateMotionRecording`) recordings. + motion_info_list : list[dict] + A list of `motion_info` objects, as output from `correct_motion`. + Each entry should correspond to a recording in `recording_list`. + align_sessions_kwargs : dict + A dictionary of keyword arguments passed to `align_sessions`. + + TODO + ---- + add a test that checks the output of motion_info created + by correct_motion is as expected. 
+ """ + # Check motion kwargs are the same across all recordings + motion_kwargs_list = [info["parameters"]["estimate_motion_kwargs"] for info in motion_info_list] + if not all(kwargs == motion_kwargs_list[0] for kwargs in motion_kwargs_list): + raise ValueError( + "The motion correct settings used on the `recordings_list` must be identical for all recordings" + ) + + motion_window_kwargs = copy.deepcopy(motion_kwargs_list[0]) + if motion_window_kwargs["direction"] != "y": + raise ValueError("motion correct must have been performed along the 'y' dimension.") + + if align_sessions_kwargs is None: + align_sessions_kwargs = get_compute_alignment_kwargs() + + # If motion correction was nonrigid, we must use the same settings for + # inter-session alignment, or we will not be able to add the nonrigid + # shifts together. + if ( + "non_rigid_window_kwargs" in align_sessions_kwargs + and "nonrigid" in align_sessions_kwargs["non_rigid_window_kwargs"]["rigid_mode"] + ): + + if motion_window_kwargs["rigid"] is False: + print( + "Nonrigid inter-session alignment must use the motion correct " + "nonrigid settings. Overwriting any passed `non_rigid_window_kwargs` " + "with the motion object non_rigid_window_kwargs." + ) + motion_window_kwargs.pop("method") + motion_window_kwargs.pop("direction") + align_sessions_kwargs = copy.deepcopy(align_sessions_kwargs) + align_sessions_kwargs["non_rigid_window_kwargs"] = motion_window_kwargs + + return align_sessions( + recordings_list, + [info["peaks"] for info in motion_info_list], + [info["peak_locations"] for info in motion_info_list], + **align_sessions_kwargs, + ) + + +def compute_peaks_locations_for_session_alignment( + recording_list: list[BaseRecording], + detect_kwargs: dict, + localize_peaks_kwargs: dict, + job_kwargs: dict | None = None, + gather_mode: str = "memory", +): + """ + A convenience function to compute `peaks_list` and `peak_locations_list` + from a list of recordings, for `align_sessions`. + + Parameters + ---------- + recording_list : list[BaseRecording] + A list of recordings to compute `peaks` and + `peak_locations` for. + detect_kwargs : dict + Arguments to be passed to `detect_peaks`. + localize_peaks_kwargs : dict + Arguments to be passed to `localise_peaks`. + job_kwargs : dict | None + `job_kwargs` for `run_node_pipeline()`. + gather_mode : str + The mode for `run_node_pipeline()`. + """ + if job_kwargs is None: + job_kwargs = {} + + peaks_list = [] + peak_locations_list = [] + + for recording in recording_list: + peaks, peak_locations, _ = run_peak_detection_pipeline_node( + recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs + ) + peaks_list.append(peaks) + peak_locations_list.append(peak_locations) + + return peaks_list, peak_locations_list + + +############################################################################### +# Private Functions +############################################################################### + + +def _compute_session_histograms( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + histogram_type, # TODO think up better names + bin_um: float, + method: str, + chunked_bin_size_s: float | "estimate", + depth_smooth_um: float, + log_scale: bool, + weight_with_amplitude: bool, +) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]: + """ + Compute a 1d activity histogram for the session. 
As + sessions may be long, the approach taken is to chunk + the recording into time segments and compute separate + histograms for each. Then, a summary statistic is computed + over the histograms. This accounts for periods of noise + in the recording or segments of irregular spiking. + + Parameters + ---------- + see `align_sessions` for `recording_list`, `peaks_list`, + `peak_locations_list`. + + see `get_estimate_histogram_kwargs()` for all other kwargs. + + Returns + ------- + + session_histogram_list : list[np.ndarray] + A list of activity histograms (1 x n_bins), one per session. + This is the histogram which summarises all chunked histograms. + + temporal_bin_centers_list : list[np.ndarray] + A list of temporal bin centers, one per session. We have one + histogram per session, the temporal bin has 1 entry, the + mid-time point of the session. + + spatial_bin_centers : np.ndarray + A list of spatial bin centers corresponding to the session + activity histograms. + + spatial_bin_edges : np.ndarray + The corresponding spatial bin edges + + histogram_info_list : list[dict] + A list of extra information on the histograms generation + (e.g. chunked histograms). One per session. See + `_get_single_session_activity_histogram()` for details. + """ + # Get spatial windows (shared across all histograms) + # and estimate the session histograms + temporal_bin_centers_list = [] + + spatial_bin_centers, spatial_bin_edges, _ = get_spatial_bins( + recordings_list[0], direction="y", hist_margin_um=0, bin_um=bin_um + ) + + session_histogram_list = [] + histogram_info_list = [] + + for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list): + + session_hist, temporal_bin_centers, histogram_info = _get_single_session_activity_histogram( + recording, + peaks, + peak_locations, + histogram_type, + spatial_bin_edges, + method, + log_scale, + chunked_bin_size_s, + depth_smooth_um, + weight_with_amplitude, + ) + temporal_bin_centers_list.append(temporal_bin_centers) + session_histogram_list.append(session_hist) + histogram_info_list.append(histogram_info) + + return ( + session_histogram_list, + temporal_bin_centers_list, + spatial_bin_centers, + spatial_bin_edges, + histogram_info_list, + ) + + +def _get_single_session_activity_histogram( + recording: BaseRecording, + peaks: np.ndarray, + peak_locations: np.ndarray, + histogram_type, + spatial_bin_edges: np.ndarray, + method: str, + log_scale: bool, + chunked_bin_size_s: float | "estimate", + depth_smooth_um: float, + weight_with_amplitude: bool, +) -> tuple[np.ndarray, np.ndarray, dict]: + """ + Compute an activity histogram for a single session. + The recording is chunked into time segments, histograms + estimated and a summary statistic calculated across the histograms + + Note if `chunked_bin_size_is` is set to `"estimate"` the + histogram for the entire session is first created to get a good + estimate of the firing rates. + The firing rates are used to use a time segment size that will + allow a good estimation of the firing rate. + + Parameters + ---------- + `spatial_bin_edges : np.ndarray + The spatial bin edges for the created histogram. This is + explicitly required as for inter-session alignment, the + session histograms must share bin edges. + + see `_compute_session_histograms()` for all other keyword arguments. + + Returns + ------- + session_histogram : np.ndarray + Summary activity histogram for the session. 
+ temporal_bin_centers : np.ndarray + Temporal bin center (session mid-point as we only have + one time point) for the session. + histogram_info : dict + A dict of additional info including: + "chunked_histograms" : The chunked histograms over which + the summary histogram was calculated. + "chunked_temporal_bin_centers" : The temporal vin centers + for the chunked histograms, with length num_chunks. + "session_std" : The mean across bin-wise standard deviation + of the chunked histograms. + "chunked_bin_size_s" : time of each chunk used to + calculate the chunked histogram. + """ + times = recording.get_times() + temporal_bin_centers = np.atleast_1d((times[-1] + times[0]) / 2) + + # Estimate an entire session histogram if requested or doing + # full estimation for chunked bin size + if chunked_bin_size_s == "estimate": + + one_bin_histogram, _, _ = alignment_utils.get_activity_histogram( + recording, + peaks, + peak_locations, + spatial_bin_edges, + log_scale=False, + bin_s=None, + depth_smooth_um=None, + scale_to_hz=False, + weight_with_amplitude=weight_with_amplitude, + ) + + # It is important that the passed histogram is scaled to firing rate in Hz + scaled_hist = one_bin_histogram / recording.get_duration() + chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist) + chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()]) + + if histogram_type == "activity_1d": + + chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram( + recording, + peaks, + peak_locations, + spatial_bin_edges, + log_scale, + bin_s=chunked_bin_size_s, + depth_smooth_um=depth_smooth_um, + scale_to_hz=True, + ) + + elif histogram_type in ["activity_2d", "locations_2d"]: + + if histogram_type == "activity_2d": + from spikeinterface.sortingcomponents.motion.motion_utils import make_3d_motion_histograms + + chunked_histograms, chunked_temporal_bin_edges, _ = make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction="y", + bin_s=chunked_bin_size_s, + bin_um=None, + hist_margin_um=50, + num_amp_bins=20, # + log_transform=log_scale, + spatial_bin_edges=spatial_bin_edges, + ) + + else: + chunked_histograms, chunked_temporal_bin_edges = _get_peak_positions_as_histogram( + recording, spatial_bin_edges, chunked_bin_size_s, peaks, peak_locations + ) + + chunked_temporal_bin_centers = alignment_utils.get_bin_centers(chunked_temporal_bin_edges) + + if method == "chunked_mean": + session_histogram, hist_variability = alignment_utils.get_chunked_hist_mean(chunked_histograms) + + elif method == "chunked_median": + session_histogram, hist_variability = alignment_utils.get_chunked_hist_median(chunked_histograms) + + elif method == "chunked_supremum": + session_histogram, hist_variability = alignment_utils.get_chunked_hist_supremum(chunked_histograms) + + elif method == "chunked_poisson": + session_histogram, hist_variability = alignment_utils.get_chunked_hist_poisson_estimate(chunked_histograms) + + elif method == "first_eigenvector": + session_histogram, hist_variability = alignment_utils.get_chunked_hist_eigenvector(chunked_histograms) + + elif method == "chunked_gp": # TODO: better name + session_histogram, hist_variability, gp_model = alignment_utils.get_chunked_gaussian_process_regression( + chunked_histograms + ) + + # Take the average variability across bins as a summary measure. 
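    # A minimal sketch of what the chunked summary statistics dispatched above compute,
    # assuming `chunked_histograms` is a (num_chunks x num_spatial_bins) array. The real
    # `alignment_utils.get_chunked_hist_*` helpers may differ in detail (scaling, robust
    # estimators, NaN handling); this only illustrates summarising across time chunks
    # while keeping a per-bin variability estimate. Helper names are illustrative.

    import numpy as np

    def _chunked_mean_sketch(chunked_histograms):
        """Mean activity per spatial bin across chunks, plus bin-wise spread."""
        summary = np.mean(chunked_histograms, axis=0)
        variability = np.std(chunked_histograms, axis=0)
        return summary, variability

    def _chunked_median_sketch(chunked_histograms):
        """Median is more robust when a few chunks are noisy or unstable."""
        summary = np.median(chunked_histograms, axis=0)
        variability = np.std(chunked_histograms, axis=0)
        return summary, variability

    # Example: 3 chunks x 5 spatial bins, one chunk with an artefact in bin 1
    # chunked = np.array([[0.0, 1.0, 4.0, 1.0, 0.0],
    #                     [0.0, 1.2, 3.8, 0.9, 0.0],
    #                     [0.0, 5.0, 4.1, 1.1, 0.0]])
    # _chunked_median_sketch(chunked)[0]  # -> array([0. , 1.2, 4. , 1. , 0. ])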
+ session_mean_variability = np.mean(hist_variability) + + histogram_info = { + "chunked_histograms": chunked_histograms, + "chunked_temporal_bin_centers": chunked_temporal_bin_centers, + "session_mean_variability": session_mean_variability, + "chunked_bin_size_s": chunked_bin_size_s, + "session_histogram_variation": hist_variability, + } + + if method == "chunked_gp": + histogram_info.update({"gp_model": gp_model}) + + return session_histogram, temporal_bin_centers, histogram_info + + +def _get_peak_positions_as_histogram(recording, spatial_bin_edges, chunked_bin_size_s, peaks, peak_locations): + """ + This is just a temp function to see how it goes... + + # TODO: could add smoothing + """ + min_x = np.min(peak_locations["x"]) + max_x = np.max(peak_locations["x"]) + + num_x_bins = 20 # guess + x_bins = np.linspace(min_x, max_x, num_x_bins) + + # basically direct copy from make_3d_motion_histograms + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples - 1) + bin_s = chunked_bin_size_s + chunked_temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) + + arr = np.zeros((peaks.size, 3), dtype="float64") + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) + arr[:, 1] = peak_locations["y"] + arr[:, 2] = peak_locations["x"] + + chunked_histograms, _ = np.histogramdd(arr, (chunked_temporal_bin_edges, spatial_bin_edges, x_bins)) + + return chunked_histograms, chunked_temporal_bin_edges + + +def _create_motion_recordings( + recordings_list: list[BaseRecording], + shifts_array: np.ndarray, + temporal_bin_centers_list: list[np.ndarray], + non_rigid_window_centers: np.ndarray, + interpolate_motion_kwargs: dict, +) -> tuple[list[BaseRecording], list[Motion]]: + """ + Given a set of recordings, motion shifts and bin information per-recording, + generate an InterpolateMotionRecording. If the recording is already an + InterpolateMotionRecording, then the shifts will be added to a copy + of it. Copies of the Recordings are made, nothing is changed in-place. + + Parameters + ---------- + shifts_array : num_sessions x num_nonrigid bins + + Returns + ------- + corrected_recordings_list : list[BaseRecording] + A list of InterpolateMotionRecording recordings of shift-corrected + recordings corresponding to `recordings_list`. + + motion_objects_list : list[Motion] + A list of Motion objects. If the recording in `recordings_list` + is already an InterpolateMotionRecording, this will be `None`, as + no motion object is created (the existing motion object is added to) + """ + assert all(array.ndim == 1 for array in shifts_array), "time dimension should be 1 for session displacement" + + corrected_recordings_list = [] + motion_objects_list = [] + for ses_idx, recording in enumerate(recordings_list): + + session_shift = shifts_array[ses_idx][np.newaxis, :] + + motion = Motion([session_shift], [temporal_bin_centers_list[ses_idx]], non_rigid_window_centers, direction="y") + motion_objects_list.append(motion) + + if isinstance(recording, InterpolateMotionRecording): + + print("Recording is already an `InterpolateMotionRecording. 
Adding shifts directly the recording object.") + + corrected_recording = _add_displacement_to_interpolate_recording(recording, motion) + else: + corrected_recording = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) + + corrected_recordings_list.append(corrected_recording) + + return corrected_recordings_list, motion_objects_list + + +def _add_displacement_to_interpolate_recording( + original_recording: BaseRecording, + session_displacement_motion: Motion, +): + """ + This function adds a shift to an InterpolateMotionRecording. + + There are four cases: + - The original recording is rigid and new shift is rigid (shifts are added). + - The original recording is rigid and new shifts are non-rigid (sets the + non-rigid shifts onto the recording, then adds back the original shifts). + - The original recording is nonrigid and the new shifts are rigid (rigid + shift added to all nonlinear shifts) + - The original recording is nonrigid and the new shifts are nonrigid + (respective non-rigid shifts are added, must have same number of + non-rigid windows). + + Parameters + ---------- + see `_create_motion_recordings()` + + Returns + ------- + corrected_recording : InterpolateMotionRecording + A copy of the `recording` with new shifts added. + + TODO + ---- + Check + ask Sam if any other fields need to be changed. This is a little + hairy (4 possible combinations of new and old displacement shapes, + rigid or nonrigid, so test thoroughly. + """ + # Everything is done in place, so keep a short variable + # name reference to the new recordings `motion` object + # and update it.okay + corrected_recording = copy.deepcopy(original_recording) + + shifts_to_add = session_displacement_motion.displacement[0] + new_non_rigid_window_centers = session_displacement_motion.spatial_bins_um + + motion_ref = corrected_recording._recording_segments[0].motion + recording_bins = motion_ref.displacement[0].shape[1] + + # If the new displacement is a scalar (i.e. rigid), + # just add it to the existing displacements + if shifts_to_add.shape[1] == 1: + motion_ref.displacement[0] += shifts_to_add[0, 0] + + else: + if recording_bins == 1: + # If the new displacement is nonrigid (multiple windows) but the motion + # recording is rigid, we update the displacement at all time bins + # with the new, nonrigid displacement added to the old, rigid displacement. + num_time_bins = motion_ref.displacement[0].shape[0] + tiled_nonrigid_displacement = np.repeat(shifts_to_add, num_time_bins, axis=0) + shifts_to_add = tiled_nonrigid_displacement + motion_ref.displacement + + motion_ref.displacement = shifts_to_add + motion_ref.spatial_bins_um = new_non_rigid_window_centers + else: + # Otherwise, if both the motion and new displacement are + # nonrigid, we need to make sure the nonrigid windows + # match exactly. + assert np.array_equal(motion_ref.spatial_bins_um, new_non_rigid_window_centers) + assert motion_ref.displacement[0].shape[1] == shifts_to_add.shape[1] + + motion_ref.displacement[0] += shifts_to_add + + return corrected_recording + + +def _correct_session_displacement( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + motion_objects_list: list[Motion], + spatial_bin_edges: np.ndarray, + estimate_histogram_kwargs: dict, +): + """ + Internal function to apply the correction from `align_sessions` + to build a corrected histogram for comparison. First, create + new shifted peak locations. 
Then, create a new 'corrected' + activity histogram from the new peak locations. + + Parameters + ---------- + see `align_sessions()` for parameters. + + Returns + ------- + corrected_peak_locations_list : list[np.ndarray] + A list of peak locations corrected by the inter-session + shifts (one entry per session). + corrected_session_histogram_list : list[np.ndarray] + A list of histograms calculated from the corrected peaks (one per session). + """ + corrected_peak_locations_list = [] + + for recording, peaks, peak_locations, motion in zip( + recordings_list, peaks_list, peak_locations_list, motion_objects_list + ): + + # Note this `motion` is not necessarily the same as the motion on the recording. If the recording + # is an `InterpolateMotionRecording`, it will contain correction for both motion and inter-session displacement. + # Here we want to correct only the motion associated with inter-session displacement. + corrected_peak_locs = correct_motion_on_peaks( + peaks, + peak_locations, + motion, + recording, + ) + corrected_peak_locations_list.append(corrected_peak_locs) + + corrected_session_histogram_list = [] + + for recording, peaks, corrected_locations in zip(recordings_list, peaks_list, corrected_peak_locations_list): + session_hist, _, _ = _get_single_session_activity_histogram( + recording, + peaks, + corrected_locations, + estimate_histogram_kwargs["histogram_type"], + spatial_bin_edges, + estimate_histogram_kwargs["method"], + estimate_histogram_kwargs["log_scale"], + estimate_histogram_kwargs["chunked_bin_size_s"], + estimate_histogram_kwargs["depth_smooth_um"], + estimate_histogram_kwargs["weight_with_amplitude"], + ) + corrected_session_histogram_list.append(session_hist) + + return corrected_peak_locations_list, corrected_session_histogram_list + + +def cross_correlate(sig1, sig2, thr=None): + xcorr = np.correlate(sig1, sig2, mode="full") + + n = sig1.size + low_cut_idx = np.arange(0, n - thr) # double check + high_cut_idx = np.arange(n + thr, 2 * n - 1) + + xcorr[low_cut_idx] = 0 + xcorr[high_cut_idx] = 0 + + if np.max(xcorr) < 0.01: + shift = 0 + else: + shift = np.argmax(xcorr) - xcorr.size // 2 + + return shift + + +def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True): + """ """ + best_correlation = 0 + best_displacements = np.zeros_like(signa11_blanked) + + # TODO: use kriging interp + + xcorr = [] + + for scale in np.linspace(0.85, 1.15, 10): + + nonzero = np.where(signa11_blanked > 0)[0] + if not np.any(nonzero): + continue + + midpoint = nonzero[0] + np.ptp(nonzero) / 2 + x_scale = (x - midpoint) * scale + midpoint + + interp_f = scipy.interpolate.interp1d( + x_scale, signa11_blanked, fill_value=0.0, bounds_error=False + ) # TODO: try cubic etc... or Kriging + + scaled_func = interp_f(x) + + # plt.plot(signa11_blanked) + # plt.plot(scaled_func) + # plt.show() + + # breakpoint() + + for sh in np.arange(-thr, thr): # TODO: we are off by one here + + shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) + + x_shift = x_scale - sh # TODO: rename + + # is this pull back? + # interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... 
or Kriging + + # scaled_func = interp_f(x_shift) + + corr_value = ( + np.correlate( + shift_signal1_blanked - np.mean(shift_signal1_blanked), + signal2_blanked - np.mean(signal2_blanked), + ) + / signa11_blanked.size + ) + + if corr_value > best_correlation: + best_displacements = x_shift + best_correlation = corr_value + + if False and np.abs(sh) == 1: + print(corr_value) + + plt.plot(shift_signal1_blanked) + plt.plot(signal2_blanked) + plt.show() + # plt.draw() # Draw the updated figure + # plt.pause(0.1) # Pause for 0.5 seconds before updating + # plt.clf() + + # breakpoint() + + # xcorr.append(np.max(np.r_[xcorr_scale])) + + if False: + xcorr = np.r_[xcorr] + # shift = np.argmax(xcorr) - thr + + print("MAX", np.max(xcorr)) + + if np.max(xcorr) < 0.0001: + shift = 0 + else: + shift = np.argmax(xcorr) - thr + + print("output shift", shift) + + return best_displacements + + +# plt.plot(signal1) +# plt.plot(signal2) + + +def get_shifts(signal1, signal2, windows, plot=True): + + import matplotlib.pyplot as plt + + signa11_blanked = signal1.copy() + signal2_blanked = signal2.copy() + + best_displacements = np.zeros_like(signal1) + + if (first_idx := windows[0][0]) != 0: + print("first idx", first_idx) + signa11_blanked[:first_idx] = 0 + signal2_blanked[:first_idx] = 0 + + if (last_idx := windows[-1][-1]) != signal1.size - 1: # double check + print("last idx", last_idx) + signa11_blanked[last_idx:] = 0 + signal2_blanked[last_idx:] = 0 + + segment_shifts = np.empty(len(windows)) + + x = np.arange(signa11_blanked.size) + x_orig = x.copy() + + for round in range(len(windows)): + + # if round == 0: + # shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! + # else: + displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False) + + # breakpoint() + + interpf = scipy.interpolate.interp1d( + displacements, signa11_blanked, fill_value=0.0, bounds_error=False + ) # TODO: move away from this indexing sceheme + signa11_blanked = interpf(x) + + # cum_shifts.append(shift) + # print("shift", shift) + + # shift the signal1, or use indexing + + # signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors... + + # if plot: + # print("round", round) + # plt.plot(signa11_blanked) + # plt.plot(signal2_blanked) + # plt.show() + + window_corrs = np.empty(len(windows)) + for i, idx in enumerate(windows): + window_corrs[i] = ( + np.correlate( + signa11_blanked[idx] - np.mean(signa11_blanked[idx]), + signal2_blanked[idx] - np.mean(signal2_blanked[idx]), + ) + / signa11_blanked[idx].size + ) + + max_window = np.argmax(window_corrs) # TODO: cutoff! + + if False: + small_shift = cross_correlate( + signa11_blanked[windows[max_window]], + signal2_blanked[windows[max_window]], + thr=windows[max_window].size // 2, + ) + signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift) + segment_shifts[max_window] = np.sum(cum_shifts) + small_shift + + best_displacements[windows[max_window]] = displacements[windows[max_window]] + + x = displacements + + signa11_blanked[windows[max_window]] = 0 + signal2_blanked[windows[max_window]] = 0 + + # TODO: need to carry over displacements! 
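    # A self-contained sketch of the basic displacement estimate the cross-correlation
    # helpers above build on: correlate two 1D activity histograms at every lag and read
    # the shift off the argmax relative to zero lag. The thresholding, rescaling and
    # windowing above are elaborations on this. The helper name and example values are
    # illustrative only.

    import numpy as np

    def _rigid_shift_sketch(hist_a, hist_b):
        """Return the integer bin shift that best aligns hist_a onto hist_b."""
        a = hist_a - np.mean(hist_a)
        b = hist_b - np.mean(hist_b)
        xcorr = np.correlate(a, b, mode="full")  # length 2 * n - 1, zero lag at index n - 1
        return int(np.argmax(xcorr)) - (hist_a.size - 1)

    # Example: a histogram displaced upward by 3 bins relative to a reference
    # reference = np.zeros(50); reference[10:20] = 1.0
    # _rigid_shift_sketch(np.roll(reference, 3), reference)  # -> 3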
+ + print(best_displacements) + interpf = scipy.interpolate.interp1d( + best_displacements, signal1, fill_value=0.0, bounds_error=False + ) # TODO: move away from this indexing sceheme + final = interpf(x_orig) + + # plt.plot(final) + # plt.plot(signal2) + # plt.show() + + return np.floor(best_displacements - x_orig) + + +def _compute_session_alignment( + session_histogram_list: list[np.ndarray], + contact_depths: np.ndarray, + spatial_bin_centers: np.ndarray, + alignment_order: str, + non_rigid_window_kwargs: dict, + compute_alignment_kwargs: dict, +) -> tuple[np.ndarray, ...]: + """ + Given a list of activity histograms (one per session) compute + rigid or non-rigid set of shifts (one per session) that will bring + all sessions into alignment. + + For rigid shifts, a cross-correlation between activity + histograms is performed. For non-rigid shifts, the probe + is split into segments, and linear estimation of shift + performed for each segment. + + Parameters + ---------- + See `align_sessions()` for parameters + + Returns + ------- + shifts : np.ndarray + A (num_sessions x num_rigid_windows) array of shifts to bring + the histograms in `session_histogram_list` into alignment. + non_rigid_windows : np.ndarray + An array (num_non_rigid_windows x num_spatial_bins) of weightings + for each bin in each window. For rect, these are in the range [0, 1], + for Gaussian these are gaussian etc. + non_rigid_window_centers : np.ndarray + The centers (spatial, in um) of each non-rigid window. + """ + session_histogram_array = np.array(session_histogram_list) + + akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid") + + rigid_mode = non_rigid_window_kwargs.pop("rigid_mode") # TODO: carefully check all popped kwargs + non_rigid_window_kwargs["rigid"] = rigid_mode == "rigid" + + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depths, spatial_bin_centers, **non_rigid_window_kwargs + ) + + rigid_shifts = _estimate_rigid_alignment( + session_histogram_array, + alignment_order, + compute_alignment_kwargs, + ) + + if rigid_mode == "rigid": + return rigid_shifts, non_rigid_windows, non_rigid_window_centers + + # For non-rigid, first shift the histograms according to the rigid shift + + # When there is non-rigid drift, the rigid drift can be very wrong! 
+ # So we depart from the kilosort approach for inter-session, + # for non-rigid, it makes sense to start without rigid alignment + shifted_histograms = session_histogram_array.copy() + + if rigid_mode == "rigid_nonrigid": # TOOD: add to docs + shifted_histograms = np.zeros_like(session_histogram_array) + for ses_idx, orig_histogram in enumerate(session_histogram_array): + + shifted_histogram = alignment_utils.shift_array_fill_zeros( + array=orig_histogram, shift=int(rigid_shifts[ses_idx, 0]) + ) + shifted_histograms[ses_idx, :] = shifted_histogram + + nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0])) + + # windows = [] + # for i in range(non_rigid_windows.shape[0]): + # idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)] + # windows.append(idxs) + # TODO: check assumptions these are always the same size + # windows = np.vstack(windows) + + num_windows = non_rigid_windows.shape[0] + + windows = np.arange(shifted_histograms.shape[1]) + windows1 = np.array_split(windows, num_windows) + + # import matplotlib.pyplot as plt + # plt.plot(non_rigid_windows.T) + # plt.show() + # num_windows = + # windows1 = windows[::2, :] + + nonrigid_session_offsets_matrix = np.empty( + (shifted_histograms.shape[0], shifted_histograms.shape[0], spatial_bin_centers.size) + ) + + print("NUM WINDOWS: ", num_windows) + + for i in range(shifted_histograms.shape[0]): + for j in range(shifted_histograms.shape[0]): + + shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=True) + + # shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) + # shifts = np.empty(shifts1.size + shifts1.size - 1) + # breakpoint() + # shifts[::2] = shifts1 + # shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2 # np.shifts2 + # breakpoint() + nonrigid_session_offsets_matrix[i, j, :] = shifts1 + + # TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S + + # breakpoint() + # Then compute the nonrigid shifts + # nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( + # shifted_histograms, non_rigid_windows, **compute_alignment_kwargs + # ) + non_rigid_shifts = nonrigid_session_offsets_matrix[ + 2, :, : + ] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + non_rigid_window_centers = spatial_bin_centers + shifts = non_rigid_shifts + + if False: + # Akima interpolate the nonrigid bins if required. + if akima_interp_nonrigid: + interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( + non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers + ) + shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts + non_rigid_window_centers = spatial_bin_centers + else: + # TODO: so check + add a test, the interpolator will handle this? + shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts + + if rigid_mode == "rigid_nonrigid": + shifts += rigid_shifts + + return shifts, non_rigid_windows, non_rigid_window_centers + + +def _estimate_rigid_alignment( + session_histogram_array: np.ndarray, + alignment_order: str, + compute_alignment_kwargs: dict, +): + """ + Estimate the rigid alignment from a set of activity + histograms, using simple cross-correlation. 
+ + Parameters + ---------- + session_histogram_array : np.ndarray + A (num_sessions x num_spatial_bins) array of activity + histograms to align + alignment_order : str + Align "to_middle" or "to_session_N" (where "N" is the session number) + compute_alignment_kwargs : dict + See `get_compute_alignment_kwargs()`. + + Returns + ------- + optimal_shift_indices : np.ndarray + An array (num_sessions x 1) of shifts to bring all + session histograms into alignment. + """ + compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) + compute_alignment_kwargs["num_shifts_block"] = False + + rigid_window = np.ones(session_histogram_array.shape[1])[np.newaxis, :] + + rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( + session_histogram_array, + rigid_window, + **compute_alignment_kwargs, # TODO: remove the copy above and pass directly. COnsider removing this function... + ) + optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix( + alignment_order, rigid_session_offsets_matrix + ) + + return optimal_shift_indices + + +# ----------------------------------------------------------------------------- +# Checkers +# ----------------------------------------------------------------------------- + + +def _check_align_sessions_inputs( + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + alignment_order: str, + estimate_histogram_kwargs: dict, +): + """ + Perform checks on the input of `align_sessions()` + """ + num_sessions = len(recordings_list) + + if len(peaks_list) != num_sessions or len(peak_locations_list) != num_sessions: + raise ValueError( + "`recordings_list`, `peaks_list` and `peak_locations_list` " + "must be the same length. They must contains list of corresponding " + "recordings, peak and peak location objects." + ) + + if not all(rec.get_num_segments() == 1 for rec in recordings_list): + raise ValueError( + "Multi-segment recordings not supported. All recordings in `recordings_list` but have only 1 segment." + ) + + channel_locs = [rec.get_channel_locations() for rec in recordings_list] + if not all(np.array_equal(locs, channel_locs[0]) for locs in channel_locs): + raise ValueError( + "The recordings in `recordings_list` do not all have " + "the same channel locations. All recordings must be " + "performed using the same probe." + ) + + accepted_hist_methods = [ + "entire_session", + "chunked_mean", + "chunked_median", + "chunked_supremum", + "first_eigenvector", + "chunked_gp", + ] + method = estimate_histogram_kwargs["method"] + if method not in accepted_hist_methods: + raise ValueError(f"`method` option must be one of: {accepted_hist_methods}") + + if alignment_order != "to_middle": + + split_name = alignment_order.split("_") + if not "_".join(split_name[:2]) == "to_session": + raise ValueError( + "`alignment_order` must take the form 'to_session_X' where X is the session number to align to." + ) + + ses_num = int(split_name[-1]) + if ses_num > num_sessions: + raise ValueError( + f"`alignment_order` session {ses_num} is larger than the number of sessions in `recordings_list`." 
+ ) + + if ses_num == 0: + raise ValueError("`alignment_order` required the session number, not session index.") diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 14c565a290..36d326b980 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -337,11 +337,10 @@ def correct_motion( for plotting. See `plot_motion_info()` """ # local import are important because "sortingcomponents" is not important by default - from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods + from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks - from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods + from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.sortingcomponents.motion import estimate_motion, InterpolateMotionRecording - from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline # get preset params and update if necessary params = motion_options_preset[preset] @@ -385,34 +384,11 @@ def correct_motion( if not do_selection: # maybe do this directly in the folder when not None, but might be slow on external storage gather_mode = "memory" - # node detect - method = detect_kwargs.pop("method", "locally_exclusive") - method_class = detect_peak_methods[method] - node0 = method_class(recording, **detect_kwargs) - - node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3) - - # node detect + localize - method = localize_peaks_kwargs.pop("method", "center_of_mass") - method_class = localize_peak_methods[method] - node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs) - pipeline_nodes = [node0, node1, node2] - t0 = time.perf_counter() - peaks, peak_locations = run_node_pipeline( - recording, - pipeline_nodes, - job_kwargs, - job_name="detect and localize", - gather_mode=gather_mode, - gather_kwargs=None, - squeeze_output=False, - folder=None, - names=None, - ) - t1 = time.perf_counter() - run_times = dict( - detect_and_localize=t1 - t0, + + peaks, peak_locations, peaks_run_time = run_peak_detection_pipeline_node( + recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs ) + run_times = dict(detect_and_localize=peaks_run_time) else: # localization is done after select_peaks() pipeline_nodes = None @@ -462,6 +438,43 @@ def correct_motion( return out +def run_peak_detection_pipeline_node(recording, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs): + """ + TODO: add docstring + """ + from spikeinterface.sortingcomponents.peak_detection import detect_peak_methods + from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline + from spikeinterface.sortingcomponents.peak_localization import localize_peak_methods + + # node detect + method = detect_kwargs.pop("method", "locally_exclusive") + method_class = detect_peak_methods[method] + node0 = method_class(recording, **detect_kwargs) + + node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3) + + # node detect + localize + method = localize_peaks_kwargs.pop("method", "center_of_mass") + method_class = localize_peak_methods[method] + node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs) + pipeline_nodes = [node0, 
node1, node2] + t0 = time.perf_counter() + peaks, peak_locations = run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + job_name="detect and localize", + gather_mode=gather_mode, + gather_kwargs=None, + squeeze_output=False, + folder=None, + names=None, + ) + run_time = time.perf_counter() - t0 + + return peaks, peak_locations, run_time + + _doc_presets = "\n" for k, v in motion_options_preset.items(): if k == "": diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index a6bb9a5145..e6c452eccd 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -3,7 +3,14 @@ from tqdm.auto import tqdm, trange -from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d +from .motion_utils import ( + Motion, + get_spatial_windows, + get_spatial_bin_edges, + make_2d_motion_histogram, + scipy_conv1d, + get_spatial_bins, +) from .dredge import normxcorr1d @@ -135,13 +142,9 @@ def run( lsqr_robust_n_iter=20, weight_with_amplitude=False, ): - - dim = ["x", "y", "z"].index(direction) - contact_depths = recording.get_channel_locations()[:, dim] - - # spatial histogram bins - spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) - spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + spatial_bin_centers, spatial_bin_edges, contact_depths = get_spatial_bins( + recording, direction, hist_margin_um, bin_um + ) # get spatial windows non_rigid_windows, non_rigid_window_centers = get_spatial_windows( diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py index 1b5eb75508..5212f51bea 100644 --- a/src/spikeinterface/sortingcomponents/motion/iterative_template.py +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -288,6 +288,8 @@ def iterative_template_registration( return optimal_shift_indices, target_spikecount_hist, shift_covs_block +# TODO: this is duplicate of get_kriging_kernel_distance() but that +# doesnt expose d parameter, could combine? 
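# Usage sketch for the `run_peak_detection_pipeline_node` helper factored out of
# `correct_motion()` in motion.py above. It wraps the detect -> extract-dense-waveforms
# -> localize node pipeline and also returns the pipeline run time. The "method" values
# shown are the defaults applied when the kwargs omit them; everything else is
# illustrative.
#
#   peaks, peak_locations, run_time = run_peak_detection_pipeline_node(
#       recording,
#       gather_mode="memory",
#       detect_kwargs={"method": "locally_exclusive"},
#       localize_peaks_kwargs={"method": "center_of_mass"},
#       job_kwargs={},
#   )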
def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): from scipy.spatial.distance import cdist diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 680d75f221..00a7dd6e05 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -59,7 +59,10 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) assert all(t.ndim == 1 for t in self.temporal_bins_s) - assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + try: + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + except: + breakpoint() def __repr__(self): nbins = self.spatial_bins_um.shape[0] @@ -68,7 +71,13 @@ def __repr__(self): else: rigid_txt = f"non-rigid - {nbins} spatial bins" - interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + if self.temporal_bins_s[0].size > 1: + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + else: + # If there is only one temporal bin (entire session), assume the bin + # left edge is zero, and take twice it for the bin size. + interval_s = self.temporal_bins_s[0][0] * 2 + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" return txt @@ -150,6 +159,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde # reshape to grid domain shape if necessary displacement = displacement.reshape(out_shape) + # TODO: hacky + if self.temporal_bins_s[segment_index].size == 1 and self.spatial_bins_um.size == 1: + assert np.all(np.isnan(displacement)) + assert self.displacement[segment_index].size == 1 + displacement[:] = self.displacement[segment_index] + return displacement def to_dict(self): @@ -318,6 +333,8 @@ def get_spatial_windows( window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border windows = [] + print("CENTERS: ", window_centers.size) + for win_center in window_centers: if win_shape == "gaussian": win = np.exp(-((spatial_bin_centers - win_center) ** 2) / (2 * win_scale_um**2)) @@ -400,6 +417,18 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um): return spatial_bins +def get_spatial_bins(recording, direction, hist_margin_um, bin_um): + # TODO: could this be merged with the above function? + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + # spatial histogram bins + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + + return spatial_bin_centers, spatial_bin_edges, contact_depths + + def make_2d_motion_histogram( recording, peaks,