Skip to content

Commit

Permalink
Play around with slope drift for generate_drifting_recording.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 15, 2024
1 parent 9509da0 commit 58a6962
Showing 1 changed file with 43 additions and 2 deletions.
45 changes: 43 additions & 2 deletions src/spikeinterface/preprocessing/inter_session_displacement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import copy

import numpy as np
import json
from pathlib import Path
Expand Down Expand Up @@ -45,6 +47,7 @@ def correct_inter_session_displacement(
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
from spikeinterface.sortingcomponents.motion_utils import Motion

# TODO: do not accept multi-segment recordings.
# TODO: check all recordings have the same probe dimensions!
Expand Down Expand Up @@ -128,6 +131,7 @@ def correct_inter_session_displacement(
spatial_bin_edges=None,
)
else:
assert NotImplementedError
motion_histogram = make_3d_motion_histograms(
recording,
peaks,
Expand All @@ -141,6 +145,9 @@ def correct_inter_session_displacement(
spatial_bin_edges=None,
)
motion_histogram_list.append(motion_histogram[0].squeeze())
# store bin edges
temporal_bin_edges = motion_histogram[1]
spatial_bin_edges = motion_histogram[2]

# Do some checks on temporal and spatial bin edges that they are all the same?
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
Expand Down Expand Up @@ -180,8 +187,42 @@ def correct_inter_session_displacement(
# TODO: do multi-session optimisation

# Handle drift
interpolate_motion_kwargs = {}

# TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
# Will need the y-axis bins for this
motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction)
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
all_recording_corrected = []
all_motion_info = []
for i, recording in enumerate(recordings_list):

# TODO: direct copy, use 'get_window' from motion machinery
bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0
n = bin_centers.size
non_rigid_windows = [np.ones(n, dtype="float64")]
middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0
non_rigid_window_centers = np.array([middle])

motion_array = shifts[i] # TODO: this is the rigid case!
temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1])
motion = Motion(
[np.atleast_2d(motion_array)], [temporal_bins], non_rigid_window_centers, direction="y"
) # will be same for all except for shifts
all_motion_info.append(motion) # not certain on this

if isinstance(recording, InterpolateMotionRecording):
raise NotImplementedError
recording_corrected = copy.deepcopy(recording)
# TODO: add interpolation to the existing one.
# Not if inter-session motion correction already exists, but further
# up the preprocessing chain, it will NOT be added and interpolation
# will occur twice. Throw a warning here!
else:
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
all_recording_corrected.append(recording_corrected)

displacement_info = {
"all_motion_info": all_motion_info,
"all_motion_histograms": motion_histogram_list, # TODO: naming
"all_shifts": shifts,
}
return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object

0 comments on commit 58a6962

Please sign in to comment.