Fix a cross-band interpolation bug, and allow time_vector in interpolate_motion #3517

Merged
merged 19 commits into from
Nov 22, 2024

@cwindolf cwindolf commented Nov 4, 2024

Hi all, this PR addresses two issues related to motion interpolation, specifically estimating motion in the LFP band and applying the correction to the AP band.

I also added a small test for cross-band registration. The idea was to generate an "LFP" and an "AP" recording, which are zeros except for a channel of 1s which jumps after 3s. Motion is estimated from the LFP recording and then used to interpolate the AP recording, and we assert that the result has no drift at all.
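A minimal NumPy sketch of the synthetic traces described above, assuming hypothetical sampling rates, channel indices, and a helper name (`make_jump_traces`) that do not come from the PR. It only builds the arrays; wrapping them in recording objects and running motion estimation is left out.

```python
import numpy as np

# Assumed values for illustration only; the description does not state them.
fs_lfp = 50.0        # hypothetical LFP sampling rate (Hz)
fs_ap = 300.0        # hypothetical AP sampling rate (Hz)
num_channels = 10
total_duration_s = 5.0
jump_time_s = 3.0    # "jumps after 3s", per the description


def make_jump_traces(fs, duration_s, jump_s, n_chans, chan_before=4, chan_after=5):
    """Zeros everywhere except one channel of 1s that moves at jump_s."""
    n_samples = int(fs * duration_s)
    jump_sample = int(fs * jump_s)
    traces = np.zeros((n_samples, n_chans), dtype="float32")
    traces[:jump_sample, chan_before] = 1.0
    traces[jump_sample:, chan_after] = 1.0
    return traces


traces_lfp = make_jump_traces(fs_lfp, total_duration_s, jump_time_s, num_channels)
traces_ap = make_jump_traces(fs_ap, total_duration_s, jump_time_s, num_channels)
```

A perfectly corrected AP array would have its 1s on a single channel for the whole duration, which is what "no drift at all" means here.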

Interestingly, since the LFP recording's time bins are treated as bin centers by interpolate_motion but as left bin edges by time_to_sample_index(), there is a half-bin offset between the motion and the recording which I had to account for in this test. But, I actually think this is the right way to go? 2 reasons:

  • interpolate_motion: It makes sense to treat motion time bins as centers even when they correspond to LFP recording samples; this artificial case is not realistic and in real life we believe that behavior leads to less error.
  • time_to_sample_index(): obviously, would not want to change the behavior here, everybody expects the time in seconds to be "floored" to the previous sample.
    • If I'm wrong about that, then the searchsorted() line needs to be changed in time_to_sample_index() to map to the nearest bin, and logic similar to what I've written in interpolate_motion_on_traces() can be used. But, that's a separate issue from this PR -- just wanted to mention that in case my understanding is off.
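The half-bin offset between the two conventions can be seen in a small NumPy sketch. The sampling rate and query times are assumptions chosen for illustration, and the code mimics the two conventions with `np.searchsorted` rather than calling the library.

```python
import numpy as np

fs_lfp = 50.0                         # hypothetical LFP rate (Hz)
lfp_times_s = np.arange(5) / fs_lfp   # 0.00, 0.02, 0.04, 0.06, 0.08
half_bin_s = 0.5 / fs_lfp

# Two query times inside the sample period that starts at 0.02 s.
query_s = np.array([0.025, 0.035])

# "Left edge" convention: floor each time to the previous sample.
floor_idx = np.searchsorted(lfp_times_s, query_s, side="right") - 1

# "Bin center" convention: each sample time is the middle of its bin,
# so the bin edges sit half a sample period on either side.
edges_s = np.concatenate([lfp_times_s - half_bin_s, [lfp_times_s[-1] + half_bin_s]])
center_idx = np.searchsorted(edges_s, query_s, side="right") - 1

print(floor_idx)   # [1 1]
print(center_idx)  # [1 2]
```

The two conventions agree in the first half of each sample period and disagree in the second half, which is the half-bin shift the test has to account for.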

(And, to @DaohanZhang: this PR together with a recent dredge repo commit (evarol/dredge@5cfe179) should fix your issue, if you'd like to test them. I've updated the demo notebook there to reflect the latest updates -- dredge's understanding of spikeinterface's time handling needed to be updated.)

@cwindolf cwindolf added the bug and motion correction labels Nov 4, 2024
@@ -921,11 +917,11 @@ def time_to_sample_index(self, time_s):
                 sample_index = time_s * self.sampling_frequency
             else:
                 sample_index = (time_s - self.t_start) * self.sampling_frequency
-            sample_index = round(sample_index)
+            sample_index = np.round(sample_index).astype(int)
Collaborator Author

Flagging this change! Sorry if this is not relevant enough to this PR, but I thought that while I was working on time logic it would be good to fix this last small quality of life thing (vectorizing time_to_sample_index -- note that the scalar case still behaves the same).
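A standalone sketch of the rounding change, written as a free function with a hypothetical name (`time_to_sample_index_sketch`) so it runs without the library. It covers only the branch shown in the diff and is not the method itself.

```python
import numpy as np


def time_to_sample_index_sketch(time_s, sampling_frequency, t_start=None):
    """Illustrative free-function version of the branch shown in the diff."""
    if t_start is None:
        sample_index = time_s * sampling_frequency
    else:
        sample_index = (time_s - t_start) * sampling_frequency
    # np.round works element-wise, so scalars and arrays take the same path.
    return np.round(sample_index).astype(int)


scalar_index = time_to_sample_index_sketch(1.0, 30_000.0)
array_index = time_to_sample_index_sketch(np.array([10.0, 10.25]), 100.0, t_start=10.0)
```

With a scalar input the result is a NumPy integer scalar rather than a built-in `int`, but it compares and indexes the same way.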

Collaborator Author

I guess this follows @h-mayorquin's comment

Collaborator

Interesting! This is nice. I think the only consideration is possible overflow for longer recordings, since int64 has a fixed maximum while Python's int is unbounded. @h-mayorquin has been focussing on this more, but looking at a quick example below it should be fine.

The int64 max value is 9,223,372,036,854,775,807. If we take a Neuropixels recording that runs continuously for 2 months (not infeasible these days), we have 30,000 * 60 * 60 * 24 * 60 = 155,520,000,000 samples (samples per second x seconds per minute x minutes per hour x hours per day x ~days in 2 months) (please check). But maybe in 5 years people are sampling at 100 kHz and doing year-long recordings 😆, in which case we would have a max index of 3.1536e+12. So I think this should be sufficient under all feasible uses, but it is something to consider.

Collaborator Author

Interesting... wait, maybe I'm doing the math wrong, but don't we have:

# int64 max val     | samples/sec | sec/min | min/hr | hr/day
9223372036854775807 /     100_000 /      60 /     60 /     24
# => 1_067_519_911.6730065 days

which is quite a long time? (wolfram double check)

Collaborator

No, I think we're good. I just meant that, with that example, we would have a max index of 3.1536e+12 for one year out of a possible 9223372036854775807; the ratio is about 2924712.08678, which is your 1_067_519_911.6730065 / 365. I think 1_067_519_911.6730065 days at 100 kHz is a much better way of putting it, which really shows how sufficient this is!
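The arithmetic in this thread can be rechecked with plain Python integers, which do not overflow. The variable names are illustrative only.

```python
import numpy as np

INT64_MAX = int(np.iinfo(np.int64).max)  # 9_223_372_036_854_775_807
SECONDS_PER_DAY = 60 * 60 * 24

# 30 kHz for roughly two months (60 days).
samples_two_months_30khz = 30_000 * SECONDS_PER_DAY * 60

# 100 kHz for one year (365 days).
samples_one_year_100khz = 100_000 * SECONDS_PER_DAY * 365

# Whole days of 100 kHz recording before an int64 sample index would overflow.
days_until_overflow_100khz = INT64_MAX // (100_000 * SECONDS_PER_DAY)

print(samples_two_months_30khz)    # 155520000000
print(samples_one_year_100khz)     # 3153600000000
print(days_until_overflow_100khz)  # 1067519911
```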

@cwindolf
Collaborator Author

Hi @samuelgarcia -- updated this PR with the stuff we talked about. To recap: we need the interpolation time bin edges to interpolate traces, which we were recomputing over and over for each interpolation chunk. Now they are cached. They are cached in the motion object, but also in the interpolated recording object if for example a user specifies their own interpolation time bin size.
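The caching idea can be sketched with `functools.cached_property`. The class below is a hypothetical single-segment stand-in, not the library's motion object, and it assumes uniformly spaced bin centers; the counter exists only to make the caching visible.

```python
from functools import cached_property

import numpy as np


class MotionBinsSketch:
    """Hypothetical single-segment stand-in, not the library's motion object."""

    def __init__(self, temporal_bin_centers_s):
        self.temporal_bin_centers_s = np.asarray(temporal_bin_centers_s, dtype=float)
        self.n_edge_computations = 0  # only here to show how often edges are built

    @cached_property
    def temporal_bin_edges_s(self):
        # Runs once; later attribute accesses reuse the stored array.
        self.n_edge_computations += 1
        centers = self.temporal_bin_centers_s
        half_bin = 0.5 * (centers[1] - centers[0])  # assumes uniform spacing
        return np.concatenate([centers - half_bin, [centers[-1] + half_bin]])


motion = MotionBinsSketch([0.5, 1.5, 2.5])
for _chunk in range(3):  # stands in for successive interpolation chunks
    edges = motion.temporal_bin_edges_s
```

After the loop the edges have been computed once, however many chunks asked for them.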

I also added a small thing to vectorize time_to_sample_index, flagged in a comment above. If that feels outside the scope here I'm happy to put it in another PR!

Collaborator

@JoeZiminski JoeZiminski left a comment

Hey Charlie, this is very nice. I agree searchsorted is succinct and should pose no issues with performance. The new test is great: it will help the code be robust going forward and also helps explain what the functions are doing. Please see some general comments / suggestions; none of these are major, so feel free to ignore.

Out of interest, what was the particular change that fixed the cross-band bug? I am trying to visualise the problem. So we have the AP band samples at 30 kHz. We have AP-band motion correction time segments e.g. 2s per segment, with each segment associated with some displacement to apply. But we can also have LFP-band segments over which the drift is estimated. However because the sampling rate is different for the LFP band, in some cases this led to alignment issues when finding the nearest AP band sample to an LFP motion bin?

time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
    time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
Collaborator

Could this always be float, as we are multiplying by 0.5? If the dtypes need to be the same, should we instead cast time_bin_centers_s to float?
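A small NumPy sketch of the concern, assuming integer-valued bin centers are passed in. The outer edges are filled with placeholder values purely so the arrays are fully initialised; that choice is an assumption, not the PR's logic.

```python
import numpy as np

centers_int = np.array([0, 1, 2, 3])  # integer-valued bin centers
midpoints = 0.5 * (centers_int[1:] + centers_int[:-1])  # [0.5, 1.5, 2.5]

# Allocating with the centers' dtype: assigning floats truncates them silently.
edges_int = np.empty(centers_int.shape[0] + 1, dtype=centers_int.dtype)
edges_int[[0, -1]] = centers_int[[0, -1]]  # placeholder outer edges
edges_int[1:-1] = midpoints

# Allocating as float keeps the half-integer midpoints.
edges_float = np.empty(centers_int.shape[0] + 1, dtype=float)
edges_float[[0, -1]] = centers_int[[0, -1]]  # placeholder outer edges
edges_float[1:-1] = midpoints

print(edges_int[1:-1])    # [0 1 2]
print(edges_float[1:-1])  # [0.5 1.5 2.5]
```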

@@ -576,3 +577,24 @@ def make_3d_motion_histograms(
motion_histograms = np.log2(1 + motion_histograms)

return motion_histograms, temporal_bin_edges, spatial_bin_edges


def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
Collaborator

A docstring would be useful here, just to explain a) the case in which this is used and b) a brief overview of what it is doing.

If I understand correctly, we need both bin centres and bin edges. Given some bin centres, we compute the edges, or vice versa given some bin edges we compute the centres?
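That reading can be sketched as follows. The function name is hypothetical, and the treatment of the two outermost edges (mirroring the nearest interior edge about the first and last center) is an assumption made for this example; only the midpoint formula for centers appears in the diff.

```python
import numpy as np


def ensure_time_bins_sketch(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return (centers, edges), deriving whichever one was not supplied.

    Illustrative only: the outermost-edge handling is an assumption.
    """
    if time_bin_centers_s is None and time_bin_edges_s is None:
        raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

    if time_bin_centers_s is None:
        edges = np.asarray(time_bin_edges_s, dtype=float)
        centers = 0.5 * (edges[1:] + edges[:-1])  # midpoints of adjacent edges
        return centers, edges

    centers = np.asarray(time_bin_centers_s, dtype=float)
    if time_bin_edges_s is not None:
        return centers, np.asarray(time_bin_edges_s, dtype=float)

    if centers.size < 2:
        raise ValueError("This sketch needs at least two centers to infer edges.")
    edges = np.empty(centers.shape[0] + 1, dtype=float)
    edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])       # interior edges
    edges[0] = centers[0] - (edges[1] - centers[0])        # assumed left edge
    edges[-1] = centers[-1] + (centers[-1] - edges[-2])    # assumed right edge
    return centers, edges


centers_out, edges_out = ensure_time_bins_sketch(time_bin_centers_s=[1.0, 3.0, 5.0])
centers_back, _ = ensure_time_bins_sketch(time_bin_edges_s=edges_out)
```

For uniformly spaced bins the two directions round-trip: the centers recovered from the derived edges match the originals.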

bin_inds = (times - bins_start) // bin_s
bin_inds = bin_inds.astype(int)
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
bin_centers_s = motion.temporal_bin_edges_s[segment_index]
Collaborator

Is it correct here that the bin centers are assigned the temporal bin edges?

@@ -54,6 +56,7 @@ def interpolate_motion_on_traces(
segment_index=None,
channel_inds=None,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
Collaborator

Out of interest, what is the use case for allowing edges to be passed instead of centres, say vs. requiring centres only? I find this signature and the code necessary to handle either centres or edges a little confusing, but agree there are few options that allow this level of flexibility. I guess these options are typically not user-facing anyway? i.e. most users would be using the motion pipeline and can safely ignore this.

Also, a docstring addition in Parameters for interpolation_time_bin_edges_s would be great.

bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s)

# nearest interpolation bin:
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
Collaborator

Suggested change
- # seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
+ # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
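The property stated in that code comment can be checked directly in NumPy. The edge and query values are made up for illustration, and the queries are kept inside `[b[0], b[-1])` so that `b[i]` is always a valid index.

```python
import numpy as np

b = np.array([0.0, 1.0, 2.0, 3.0])     # sorted bin edges
t = np.array([0.0, 0.5, 1.0, 2.999])   # query times inside [b[0], b[-1])

i = np.searchsorted(b, t, side="right")

# With side="right", each t lies in the half-open interval [b[i-1], b[i]).
assert np.all(b[i - 1] <= t) and np.all(t < b[i])

bin_index = i - 1  # index of the bin containing each query time
print(bin_index)   # [0 0 1 2]
```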

@@ -375,8 +393,13 @@ def __init__(
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
halfbin = interpolation_time_bin_size_s / 2.0
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)
segment_interpolation_time_bin_edges_s = np.arange(
Collaborator

(mostly for line 390) Is it possible for interpolation_time_bin_centers_s to be None at this point anymore? If centers and edges are both None, it will be motion.temporal_bins_s; if it is passed, it will not be None; and if centers is None, it will be filled in by ensure_time_bins?

halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp))

# channel geometry
nc = 10
Collaborator

could this be num_channels or num_chans?

t_start = 10.0
total_duration = 5.0
nt_lfp = int(fs_lfp * total_duration)
nt_ap = int(fs_ap * total_duration)
Collaborator

is nt num_timepoints? could it be expanded?

@@ -115,6 +115,66 @@ def test_interpolation_simple():
assert np.all(traces_corrected[:, 2:] == 0)


def test_cross_band_interpolation():
Collaborator

this is a really nice test, super useful also for conceptualising the cross-band interpolation

# spatial_interpolation_kwargs={},
spatial_interpolation_kwargs={"force_extrapolate": True},
)
assert traces.shape == traces_corrected.shape
Collaborator

Outside the scope of this PR (as the tests were like this anyway), but it would be nice to extend these tests to explicitly check that the values are correct. Testing the shape will not pick up any regressions that mess up the actual computation but leave the shape intact. Unfortunately these are the least likely to be picked up, as any erroneous shapes will probably crash at runtime anyway.

I am wondering whether, if traces were a simple 3x3 array (3 channels, 3 timepoints), it would be relatively easy to compute manually the expected results of kriging, idw and NN interpolation and check them against the output of this function? Of course, this is outside the scope of the PR, so feel free to ignore and I can write an issue!
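As a rough illustration of what such a value check could look like, here is a hand-computable inverse-distance-weighting and nearest-neighbour example on a 3x3 array. The helper `idw_weights`, the channel depths, and the power of 1 are assumptions for this sketch; the library's interpolation may weight neighbours differently, so the expected values would need to be derived for its actual formula.

```python
import numpy as np


def idw_weights(source_depths_um, query_depth_um, power=1.0):
    """Hypothetical helper: normalised inverse-distance weights for one query depth."""
    dist = np.abs(np.asarray(source_depths_um, dtype=float) - query_depth_um)
    if np.any(dist == 0):
        weights = (dist == 0).astype(float)  # query sits exactly on a channel
    else:
        weights = 1.0 / dist**power
    return weights / weights.sum()


depths_um = np.array([0.0, 10.0, 20.0])
traces = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)  # shape (timepoints, channels)

# Distances from depth 5 are 5, 5, 15, so the weights are 3/7, 3/7, 1/7.
weights = idw_weights(depths_um, 5.0)
interpolated = traces @ weights

# Nearest neighbour: depth 12 is closest to the channel at depth 10.
nearest_channel = int(np.argmin(np.abs(depths_um - 12.0)))
```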

@cwindolf
Collaborator Author

Thanks a ton for the review @JoeZiminski! I think I've made changes corresponding to all of your comments. No clue why some of my comments got duplicated... that's what I get for trying out this GitHub Codespaces thing!

@samuelgarcia
Member

This is OK for me.

@alejoe91 alejoe91 merged commit 853d8a4 into SpikeInterface:main Nov 22, 2024
15 checks passed
5 participants