Fix a cross-band interpolation bug, and allow time_vector in interpolate_motion #3517
Changes from 18 commits
507b6b3
4e38ac1
726170b
e791fe1
82e2600
d8f39b5
0a201e1
ad00beb
28527d2
b3b3fcf
b02860e
df24840
91fb732
b80bad7
c890603
6d2e479
ee29fae
b4c91a0
38e0ada
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment | ||
from spikeinterface.preprocessing.filter import fix_dtype | ||
|
||
from .motion_utils import ensure_time_bin_edges, ensure_time_bins | ||
|
||
|
||
def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: | ||
""" | ||
|
@@ -54,14 +56,19 @@ def interpolate_motion_on_traces( | |
segment_index=None, | ||
channel_inds=None, | ||
interpolation_time_bin_centers_s=None, | ||
interpolation_time_bin_edges_s=None, | ||
Review comment: Out of interest, what is the use case for allowing edges to be passed instead of centres, say vs. requiring centres only? I find this signature, and the code needed to handle either centres or edges, a little confusing, but agree there are few options that allow this level of flexibility. I guess these options are typically not user-facing anyway, i.e. most users would be using the motion pipeline and can safely ignore this. Also, a docstring addition in Parameters for |
||
spatial_interpolation_method="kriging", | ||
spatial_interpolation_kwargs={}, | ||
dtype=None, | ||
): | ||
""" | ||
Apply inverse motion with spatial interpolation on traces. | ||
|
||
Traces can be full traces, but also waveforms snippets. | ||
Traces can be full traces, but also waveforms snippets. Times used for looking up | ||
displacements are controlled by interpolation_time_bin_edges_s or | ||
interpolation_time_bin_centers_s, or fall back to the Motion object's time bins | ||
by default; times in the recording outside these time bins use the closest edge | ||
bin's displacement value during interpolation. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -80,6 +87,9 @@ def interpolate_motion_on_traces( | |
interpolation_time_bin_centers_s : None or np.array | ||
Manually specify the time bins which the interpolation happens | ||
in for this segment. If None, these are the motion estimate's time bins. | ||
interpolation_time_bin_edges_s : None or np.array | ||
If present, interpolation chunks will be the time bins defined by these edges | ||
rather than interpolation_time_bin_centers_s or the motion's bins. | ||
spatial_interpolation_method : "idw" | "kriging", default: "kriging" | ||
The spatial interpolation method used to interpolate the channel locations: | ||
* idw : Inverse Distance Weighting | ||
|
@@ -119,26 +129,33 @@ def interpolate_motion_on_traces( | |
total_num_chans = channel_locations.shape[0] | ||
|
||
# -- determine the blocks of frames that will land in the same interpolation time bin | ||
time_bins = interpolation_time_bin_centers_s | ||
if time_bins is None: | ||
time_bins = motion.temporal_bins_s[segment_index] | ||
bin_s = time_bins[1] - time_bins[0] | ||
Review comment: Just to check my understanding, the |
||
bins_start = time_bins[0] - 0.5 * bin_s | ||
# nearest bin center for each frame? | ||
bin_inds = (times - bins_start) // bin_s | ||
bin_inds = bin_inds.astype(int) | ||
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: | ||
interpolation_time_bin_centers_s = motion.temporal_bin_centers_s[segment_index] | ||
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] | ||
else: | ||
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( | ||
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s | ||
) | ||
|
||
# bin the frame times according to the interpolation time bins. | ||
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] | ||
# hence the -1. doing it with "left" is not as nice -- we want t==b[0] | ||
Review comment: Wow, I cannot believe that is not the default behaviour of "left"! |
||
# to lead to i=1 (rounding down). | ||
interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 | ||
|
||
# the time bins may not cover the whole set of times in the recording, | ||
# so we need to clip these indices to the valid range | ||
np.clip(bin_inds, 0, time_bins.size, out=bin_inds) | ||
n_bins = interpolation_time_bin_edges_s.shape[0] - 1 | ||
np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) | ||
|
||
# -- what are the possibilities here anyway? | ||
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) | ||
interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) | ||
|
||
# interpolation kernel will be the same per temporal bin | ||
interp_times = np.empty(total_num_chans) | ||
current_start_index = 0 | ||
for bin_ind in bins_here: | ||
bin_time = time_bins[bin_ind] | ||
for interp_bin_ind in interpolation_bins_here: | ||
bin_time = interpolation_time_bin_centers_s[interp_bin_ind] | ||
interp_times.fill(bin_time) | ||
channel_motions = motion.get_displacement_at_time_and_depth( | ||
interp_times, | ||
|
@@ -166,16 +183,17 @@ def interpolate_motion_on_traces( | |
# ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") | ||
# plt.show() | ||
|
||
# quick search logic to find frames corresponding to this interpolation bin in the recording | ||
# quickly find the end of this bin, which is also the start of the next | ||
next_start_index = current_start_index + np.searchsorted( | ||
bin_inds[current_start_index:], bin_ind + 1, side="left" | ||
interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" | ||
) | ||
in_bin = slice(current_start_index, next_start_index) | ||
frames_in_bin = slice(current_start_index, next_start_index) | ||
|
||
# here we use a simple np.matmul even if drift_kernel can be super sparse, | ||
# because the speed of a sparse matmul is not so good when we disable multithreading (due to multiprocessing | ||
# in ChunkRecordingExecutor) | ||
np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) | ||
np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) | ||
current_start_index = next_start_index | ||
|
||
return traces_corrected | ||
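The binning logic in the hunk above can be checked in isolation. The following is a minimal sketch, not the PR's code: the `edges` and `times` arrays are made-up values, and only `np.searchsorted` and `np.clip` are used as in the diff. It shows why `side="right"` is preferred over `side="left"`, how out-of-range frames are clipped to the edge bins, and how the per-bin frame slices are found.

```python
import numpy as np

# Hypothetical bin edges (seconds) and frame times; not taken from the PR.
edges = np.array([0.0, 1.0, 2.0, 3.0])
times = np.array([-0.5, 0.0, 0.25, 1.0, 2.5, 3.0, 4.0])

# side="right" gives i with edges[i-1] <= t < edges[i], so i - 1 is the bin index.
bin_inds = np.searchsorted(edges, times, side="right") - 1
raw_right = bin_inds.copy()

# side="left" would place a frame sitting exactly on an edge into the previous bin.
raw_left = np.searchsorted(edges, times, side="left") - 1

# Frames before the first edge or at/after the last edge fall back to the edge bins.
n_bins = edges.shape[0] - 1
np.clip(bin_inds, 0, n_bins - 1, out=bin_inds)

# Walk the bins and find the contiguous block of frames in each one.
frame_slices = []
current_start_index = 0
for b in np.arange(bin_inds[0], bin_inds[-1] + 1):
    next_start_index = current_start_index + np.searchsorted(
        bin_inds[current_start_index:], b + 1, side="left"
    )
    frame_slices.append((current_start_index, int(next_start_index)))
    current_start_index = int(next_start_index)

print(raw_right.tolist(), raw_left.tolist(), bin_inds.tolist(), frame_slices)
```

Because `times` is sorted, `bin_inds` is sorted too, which is what makes the second `searchsorted` a valid way to find the end of each bin.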
|
@@ -297,6 +315,7 @@ def __init__( | |
p=1, | ||
num_closest=3, | ||
interpolation_time_bin_centers_s=None, | ||
interpolation_time_bin_edges_s=None, | ||
interpolation_time_bin_size_s=None, | ||
dtype=None, | ||
**spatial_interpolation_kwargs, | ||
|
@@ -363,9 +382,14 @@ def __init__( | |
|
||
# handle manual interpolation_time_bin_centers_s | ||
# the case where interpolation_time_bin_size_s is set is handled per-segment below | ||
if interpolation_time_bin_centers_s is None: | ||
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: | ||
if interpolation_time_bin_size_s is None: | ||
interpolation_time_bin_centers_s = motion.temporal_bins_s | ||
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s | ||
else: | ||
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( | ||
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s | ||
) | ||
|
||
for segment_index, parent_segment in enumerate(recording._recording_segments): | ||
# finish the per-segment part of the time bin logic | ||
|
@@ -375,8 +399,13 @@ def __init__( | |
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) | ||
halfbin = interpolation_time_bin_size_s / 2.0 | ||
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) | ||
segment_interpolation_time_bin_edges_s = np.arange( | ||
Review comment: (mostly for line 390) Is it possible for |
||
t_start, t_end + halfbin, interpolation_time_bin_size_s | ||
) | ||
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) | ||
else: | ||
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] | ||
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] | ||
|
||
rec_segment = InterpolateMotionRecordingSegment( | ||
parent_segment, | ||
|
@@ -387,6 +416,7 @@ def __init__( | |
channel_inds, | ||
segment_index, | ||
segment_interpolation_time_bins_s, | ||
segment_interpolation_time_bin_edges_s, | ||
dtype=dtype_, | ||
) | ||
self.add_recording_segment(rec_segment) | ||
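For the fixed bin size branch, a small sketch may help to see what the two `np.arange` calls produce. The helper name `fixed_size_bins` and the numeric inputs are hypothetical; only the two `np.arange` expressions mirror the diff. The shape relation asserted in the PR holds for these inputs, but this sketch does not show that it holds for every floating-point combination.

```python
import numpy as np

def fixed_size_bins(t_start, t_end, bin_size_s):
    """Hypothetical helper mirroring the two np.arange calls in the hunk above."""
    halfbin = bin_size_s / 2.0
    centers = np.arange(t_start + halfbin, t_end, bin_size_s)
    edges = np.arange(t_start, t_end + halfbin, bin_size_s)
    return centers, edges

# Segment length is a whole number of bins.
centers_a, edges_a = fixed_size_bins(0.0, 10.0, 2.0)

# Segment length is not a whole number of bins: the last edge (8.0) is before t_end (9.0),
# so frames after it rely on the index clipping to land in the last bin.
centers_b, edges_b = fixed_size_bins(0.0, 9.0, 2.0)

print(centers_a, edges_a)
print(centers_b, edges_b)
```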
|
@@ -420,6 +450,7 @@ def __init__( | |
channel_inds, | ||
segment_index, | ||
interpolation_time_bin_centers_s, | ||
interpolation_time_bin_edges_s, | ||
dtype="float32", | ||
): | ||
BasePreprocessorSegment.__init__(self, parent_recording_segment) | ||
|
@@ -429,13 +460,11 @@ def __init__( | |
self.channel_inds = channel_inds | ||
self.segment_index = segment_index | ||
self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s | ||
self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s | ||
self.dtype = dtype | ||
self.motion = motion | ||
|
||
def get_traces(self, start_frame, end_frame, channel_indices): | ||
if self.time_vector is not None: | ||
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") | ||
|
||
if start_frame is None: | ||
start_frame = 0 | ||
if end_frame is None: | ||
|
@@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
channel_inds=self.channel_inds, | ||
spatial_interpolation_method=self.spatial_interpolation_method, | ||
spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, | ||
interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, | ||
interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, | ||
) | ||
|
||
if channel_indices is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import warnings | ||
import json | ||
import warnings | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
@@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" | |
self.direction = direction | ||
self.dim = ["x", "y", "z"].index(direction) | ||
self.check_properties() | ||
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] | ||
|
||
def check_properties(self): | ||
assert all(d.ndim == 2 for d in self.displacement) | ||
|
@@ -576,3 +577,40 @@ def make_3d_motion_histograms( | |
motion_histograms = np.log2(1 + motion_histograms) | ||
|
||
return motion_histograms, temporal_bin_edges, spatial_bin_edges | ||
|
||
|
||
def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): | ||
"""Ensure that both bin edges and bin centers are present | ||
Review comment: A docstring would be useful here, just to explain a) the case this is used b) brief overview of what it is doing. If I understand correctly, we need both bin centres and bin edges. Given some bin centres, we compute the edges, or vice versa given some bin edges we compute the centres? |
||
|
||
If either of the inputs are None but not both, the missing is reconstructed | ||
from the present. Going from edges to centers is done by taking midpoints. | ||
Going from centers to edges is done by taking midpoints and padding with the | ||
left and rightmost centers. | ||
|
||
Parameters | ||
---------- | ||
time_bin_centers_s : None or np.array | ||
time_bin_edges_s : None or np.array | ||
|
||
Returns | ||
------- | ||
time_bin_centers_s, time_bin_edges_s | ||
""" | ||
if time_bin_centers_s is None and time_bin_edges_s is None: | ||
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") | ||
|
||
if time_bin_centers_s is None: | ||
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 | ||
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) | ||
|
||
if time_bin_edges_s is None: | ||
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) | ||
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] | ||
Review comment: Could this always be float? As we are multiplying by |
||
if time_bin_centers_s.size > 2: | ||
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) | ||
|
||
return time_bin_centers_s, time_bin_edges_s | ||
|
||
|
||
def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): | ||
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] |
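The conversion described in the `ensure_time_bins` docstring can be sketched standalone. The function names below (`centers_to_edges`, `edges_to_centers`, `centers_to_edges_same_dtype`) are hypothetical and this is not the PR's implementation. Two choices here are assumptions that differ from the diff: the edges array is always float64, and the midpoints are filled whenever there are at least two centers, so the two-center case does not leave the middle edge uninitialized. The last function keeps the input dtype to illustrate the reviewer's dtype question: with integer centers, NumPy silently truncates the midpoints on assignment.

```python
import numpy as np

def centers_to_edges(centers):
    """Midpoints between centers, padded with the first and last center (float output)."""
    centers = np.asarray(centers)
    edges = np.empty(centers.shape[0] + 1, dtype=np.float64)  # assumption: always float
    edges[[0, -1]] = centers[[0, -1]]
    if centers.size > 1:
        edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])
    return edges

def edges_to_centers(edges):
    """Bin centers are the midpoints of consecutive edges."""
    edges = np.asarray(edges)
    return 0.5 * (edges[1:] + edges[:-1])

def centers_to_edges_same_dtype(centers):
    """Same as centers_to_edges, but the output inherits the input dtype."""
    centers = np.asarray(centers)
    edges = np.empty(centers.shape[0] + 1, dtype=centers.dtype)
    edges[[0, -1]] = centers[[0, -1]]
    if centers.size > 1:
        edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])
    return edges

float_edges = centers_to_edges([1.0, 3.0, 5.0])
round_trip = edges_to_centers([0.0, 2.0, 4.0, 6.0])
int_centers = np.array([0, 3, 6])
truncated_edges = centers_to_edges_same_dtype(int_centers)
safe_edges = centers_to_edges(int_centers)
print(float_edges, round_trip, truncated_edges, safe_edges)
```

Note that padding with the outermost centers makes the first and last bins half as wide as the interior ones, so the centers recovered from these edges are not the original centers.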
Flagging this change! Sorry if this is not relevant enough to this PR, but I thought that while I was working on time logic it would be good to fix this last small quality of life thing (vectorizing time_to_sample_index -- note that the scalar case still behaves the same).
I guess this follows @h-mayorquin 's comment
Interesting! This is nice, I think the only consideration is possible overflow for longer recordings, as `int64` is capped but Python `int()` is not capped. @h-mayorquin has been focussing on this more, but looking at a quick example below it should be fine. The `int64` max value is `9,223,372,036,854,775,807`. If we take a neuropixels recording, continuous for 2 months (not unfeasible these days), we have `(30,000 * 60 * 60 * 24 * 60) = 155520000000` (samples per s x seconds per minute x minutes per hour x hours per day x ~days in 2 months) (please check). But, maybe in 5 years people are sampling at 100 kHz and doing year long recordings 😆 we would have a max index of `3.1536e+12`. So I think it should be sufficient under all feasible uses, but something to consider.
Interesting... wait, maybe I'm doing the math wrong, but don't we have:
which is quite a long time? (wolfram double check)
No I think we're good, I just meant that with that example we would have a max index of `3.1536e+12` out of a possible `9223372036854775807` for one year, a ratio of `2924712.08678`, which is your `1_067_519_911.6730065/365`. I think `1_067_519_911.6730065 days` at 100 kHz is a much better way of putting it, which really shows how sufficient this is!
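The headroom arithmetic in this thread can be reproduced with plain Python integers. This is only a check of the numbers discussed, under the thread's own assumptions (30 kHz for roughly 60 days, and 100 kHz for 365 days); the variable names are illustrative.

```python
import numpy as np

INT64_MAX = int(np.iinfo(np.int64).max)  # 9_223_372_036_854_775_807
SECONDS_PER_DAY = 60 * 60 * 24

# ~2 months of continuous recording at 30 kHz
samples_two_months_30khz = 30_000 * SECONDS_PER_DAY * 60

# 1 year of continuous recording at 100 kHz
samples_one_year_100khz = 100_000 * SECONDS_PER_DAY * 365

# How long a 100 kHz recording could run before a sample index exceeds int64
days_until_overflow = INT64_MAX / (100_000 * SECONDS_PER_DAY)
years_until_overflow = days_until_overflow / 365

print(samples_two_months_30khz)   # 155520000000
print(samples_one_year_100khz)    # 3153600000000
print(days_until_overflow, years_until_overflow)
```

Both sample counts are many orders of magnitude below the `int64` limit, which supports the conclusion reached in the comments.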