diff --git a/src/spikeinterface/preprocessing/inter_session_displacement.py b/src/spikeinterface/preprocessing/inter_session_displacement.py index a53897a762..5702d3658a 100644 --- a/src/spikeinterface/preprocessing/inter_session_displacement.py +++ b/src/spikeinterface/preprocessing/inter_session_displacement.py @@ -1,5 +1,7 @@ from __future__ import annotations +import copy + import numpy as np import json from pathlib import Path @@ -45,6 +47,7 @@ def correct_inter_session_displacement( from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline + from spikeinterface.sortingcomponents.motion_utils import Motion # TODO: do not accept multi-segment recordings. # TODO: check all recordings have the same probe dimensions! @@ -128,6 +131,7 @@ def correct_inter_session_displacement( spatial_bin_edges=None, ) else: + assert NotImplementedError motion_histogram = make_3d_motion_histograms( recording, peaks, @@ -141,6 +145,9 @@ def correct_inter_session_displacement( spatial_bin_edges=None, ) motion_histogram_list.append(motion_histogram[0].squeeze()) + # store bin edges + temporal_bin_edges = motion_histogram[1] + spatial_bin_edges = motion_histogram[2] # Do some checks on temporal and spatial bin edges that they are all the same? # TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence) @@ -180,8 +187,42 @@ def correct_inter_session_displacement( # TODO: do multi-session optimisation # Handle drift + interpolate_motion_kwargs = {} # TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object! # Will need the y-axis bins for this - motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) - recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) + all_recording_corrected = [] + all_motion_info = [] + for i, recording in enumerate(recordings_list): + + # TODO: direct copy, use 'get_window' from motion machinery + bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0 + n = bin_centers.size + non_rigid_windows = [np.ones(n, dtype="float64")] + middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0 + non_rigid_window_centers = np.array([middle]) + + motion_array = shifts[i] # TODO: this is the rigid case! + temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1]) + motion = Motion( + [np.atleast_2d(motion_array)], [temporal_bins], non_rigid_window_centers, direction="y" + ) # will be same for all except for shifts + all_motion_info.append(motion) # not certain on this + + if isinstance(recording, InterpolateMotionRecording): + raise NotImplementedError + recording_corrected = copy.deepcopy(recording) + # TODO: add interpolation to the existing one. + # Not if inter-session motion correction already exists, but further + # up the preprocessing chain, it will NOT be added and interpolation + # will occur twice. Throw a warning here! + else: + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) + all_recording_corrected.append(recording_corrected) + + displacement_info = { + "all_motion_info": all_motion_info, + "all_motion_histograms": motion_histogram_list, # TODO: naming + "all_shifts": shifts, + } + return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object