From 8e6e77ba974e0bf6557b373e6044d24ed23fd02d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 16 Jul 2024 20:17:22 +0100 Subject: [PATCH] Edit motion interpolator for 1d case --- .../motion/motion_interpolation.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 3056578070..1425a33d7a 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -120,20 +120,22 @@ def interpolate_motion_on_traces( time_bins = interpolation_time_bin_centers_s if time_bins is None: time_bins = motion.temporal_bins_s[segment_index] - bin_s = ( - time_bins[1] - time_bins[0] if time_bins.size > 1 else time_bins * 2 - ) # TODO: check this is * 2 but yes must be because its in the middle NO ITS NOT if first time is not 0 - # must use a different stragery - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? - bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) - # the time bins may not cover the whole set of times in the recording, - # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) - - # -- what are the possibilities here anyway? - bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) # TODO: just replace this with 0 + + if time_bins.size == 1: + bins_here = [0] + else: + bin_s = time_bins[1] - time_bins[0] + # must use a different stragery + bins_start = time_bins[0] - 0.5 * bin_s + # nearest bin center for each frame? + bin_inds = (times - bins_start) // bin_s + bin_inds = bin_inds.astype(int) + # the time bins may not cover the whole set of times in the recording, + # so we need to clip these indices to the valid range + np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + + # -- what are the possibilities here anyway? + bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) @@ -168,16 +170,19 @@ def interpolate_motion_on_traces( # plt.show() # quickly find the end of this bin, which is also the start of the next - next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" - ) - in_bin = slice(current_start_index, next_start_index) + if time_bins.size == 1: + in_bin = None + else: + next_start_index = current_start_index + np.searchsorted( + bin_inds[current_start_index:], bin_ind + 1, side="left" + ) + in_bin = slice(current_start_index, next_start_index) + current_start_index = next_start_index # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) - current_start_index = next_start_index return traces_corrected