Fix a cross-band interpolation bug, and allow time_vector in interpolate_motion #3517

Merged: 19 commits, Nov 22, 2024
14 changes: 5 additions & 9 deletions src/spikeinterface/core/baserecording.py
@@ -1,4 +1,5 @@
from __future__ import annotations

import warnings
from pathlib import Path

@@ -7,14 +8,9 @@

from .base import BaseSegment
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import (
convert_bytes_to_str,
convert_seconds_to_str,
)
from .recording_tools import write_binary_recording


from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs
from .recording_tools import write_binary_recording


class BaseRecording(BaseRecordingSnippets):
@@ -950,11 +946,11 @@ def time_to_sample_index(self, time_s):
sample_index = time_s * self.sampling_frequency
else:
sample_index = (time_s - self.t_start) * self.sampling_frequency
sample_index = round(sample_index)
sample_index = np.round(sample_index).astype(int)
Comment (Collaborator Author):
Flagging this change! Sorry if this is not relevant enough to this PR, but I thought that while I was working on time logic it would be good to fix this last small quality of life thing (vectorizing time_to_sample_index -- note that the scalar case still behaves the same).

Comment (Collaborator Author):
I guess this follows @h-mayorquin's comment.

Comment (Collaborator):
Interesting! This is nice. I think the only consideration is possible overflow for longer recordings, as int64 is capped but python int() is not. @h-mayorquin has been focusing on this more, but looking at a quick example below it should be fine.

int64 max value is 9,223,372,036,854,775,807. If we take a Neuropixels recording running continuously for 2 months (not unfeasible these days), we have 30,000 x 60 x 60 x 24 x 60 = 155,520,000,000 (samples per second x seconds per minute x minutes per hour x hours per day x ~days in 2 months) (please check). But maybe in 5 years people will be sampling at 100 kHz and doing year-long recordings 😆, giving a max index of ~3.1536e+12. So I think this should be sufficient under all feasible uses, but something to consider.

Comment (Collaborator Author):
Interesting... wait, maybe I'm doing the math wrong, but don't we have:

# int64 max val     | samples/sec | sec/min | min/hr | hr/day
9223372036854775807 /     100_000 /      60 /     60 /     24
# => 1_067_519_911.6730065 days

which is quite a long time? (wolfram double check)

Comment (Collaborator):
No, I think we're good. I just meant that with that example we would have a max index of 3.1536e+12 for one year, out of a possible 9,223,372,036,854,775,807, i.e. a quotient of about 2,924,712 (your 1_067_519_911.6730065 / 365, in years). I think 1_067_519_911.67 days at 100 kHz is a much better way of putting it, which really shows how sufficient this is!
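A quick standalone sanity check of the headroom numbers discussed in this thread (illustrative only, not part of the diff):

import numpy as np

# int64 headroom for sample indices: even at a hypothetical 100 kHz sampling rate,
# the largest representable index corresponds to roughly 2.9 million years of recording
fs = 100_000  # samples per second, worst-case assumption from the discussion
max_index = np.iinfo(np.int64).max
max_days = max_index / fs / 60 / 60 / 24
print(f"{max_days:,.1f} days (~{max_days / 365:,.0f} years)")  # ~1,067,519,911.7 days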

else:
sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1

return int(sample_index)
return sample_index

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal segment
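To illustrate the vectorized lookup introduced above, here is a minimal self-contained sketch; it is not the BaseRecording implementation, and t_start, sampling_frequency and time_vector are made-up stand-ins:

import numpy as np

sampling_frequency = 30_000.0
t_start = 2.0
time_vector = None  # set to a per-sample times array to exercise the searchsorted branch

def time_to_sample_index(time_s):
    # np.round(...).astype(int) handles scalars and arrays alike,
    # so the scalar case behaves as before while arrays now work too
    if time_vector is None:
        sample_index = np.round((np.asarray(time_s) - t_start) * sampling_frequency).astype(int)
    else:
        sample_index = np.searchsorted(time_vector, time_s, side="right") - 1
    return sample_index

print(time_to_sample_index(2.5))                   # 15000 (scalar in, scalar-like out)
print(time_to_sample_index(np.array([2.5, 3.0])))  # [15000 30000]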
71 changes: 50 additions & 21 deletions src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
@@ -6,6 +6,8 @@
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.preprocessing.filter import fix_dtype

from .motion_utils import ensure_time_bin_edges, ensure_time_bins


def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray:
"""
@@ -54,14 +56,19 @@ def interpolate_motion_on_traces(
segment_index=None,
channel_inds=None,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
Comment (Collaborator):
Out of interest, what is the use-case for allowing edges to be passed instead of centres, versus requiring centres only? I find this signature and the code needed to handle either centres or edges a little confusing, but I agree there are few options that allow this level of flexibility. I guess these options are typically not user-facing anyway? i.e. most users would be using the motion pipeline and can safely ignore this.

Also, a docstring addition in Parameters for interpolation_time_bin_edges_s would be great.

spatial_interpolation_method="kriging",
spatial_interpolation_kwargs={},
dtype=None,
):
"""
Apply inverse motion with spatial interpolation on traces.

Traces can be full traces, but also waveforms snippets.
Traces can be full traces, but also waveforms snippets. Times used for looking up
displacements are controlled by interpolation_time_bin_edges_s or
interpolation_time_bin_centers_s, or fall back to the Motion object's time bins
by default; times in the recording outside these time bins use the closest edge
bin's displacement value during interpolation.

Parameters
----------
@@ -80,6 +87,9 @@
interpolation_time_bin_centers_s : None or np.array
Manually specify the time bins which the interpolation happens
in for this segment. If None, these are the motion estimate's time bins.
interpolation_time_bin_edges_s : None or np.array
If present, interpolation chunks will be the time bins defined by these edges
rather than interpolation_time_bin_centers_s or the motion's bins.
spatial_interpolation_method : "idw" | "kriging", default: "kriging"
The spatial interpolation method used to interpolate the channel locations:
* idw : Inverse Distance Weighting
@@ -119,26 +129,33 @@
total_num_chans = channel_locations.shape[0]

# -- determine the blocks of frames that will land in the same interpolation time bin
time_bins = interpolation_time_bin_centers_s
if time_bins is None:
time_bins = motion.temporal_bins_s[segment_index]
bin_s = time_bins[1] - time_bins[0]
Comment (Collaborator):
Just to check my understanding, the searchsorted on the bin_edges is functionally equivalent to this approach? (but of course searchsorted is less verbose)

bins_start = time_bins[0] - 0.5 * bin_s
# nearest bin center for each frame?
bin_inds = (times - bins_start) // bin_s
bin_inds = bin_inds.astype(int)
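For what it's worth, a small standalone check with made-up uniform bins suggesting the two approaches agree, including the t == edges[0] case that motivates side="right":

import numpy as np

bin_s = 0.5
centers = np.arange(0.25, 10.0, bin_s)                         # hypothetical uniform bin centers
edges = np.concatenate([centers - bin_s / 2, [centers[-1] + bin_s / 2]])
times = np.random.default_rng(0).uniform(0.0, 10.0, 1000)

inds_floor = ((times - edges[0]) // bin_s).astype(int)         # the floor-division approach quoted above
inds_search = np.searchsorted(edges, times, side="right") - 1  # the searchsorted approach in the diff
assert np.array_equal(inds_floor, inds_search)                 # equivalent for uniform bins

# side="right" puts a time exactly on the first edge into bin 0; side="left" would give -1
print(np.searchsorted(edges, [edges[0]], side="right") - 1)    # [0]
print(np.searchsorted(edges, [edges[0]], side="left") - 1)     # [-1]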
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
interpolation_time_bin_centers_s = motion.temporal_bin_centers_s[segment_index]
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index]
else:
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
)

# bin the frame times according to the interpolation time bins.
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
# hence the -1. doing it with "left" is not as nice -- we want t==b[0]
Comment (Collaborator):
Wow, I cannot believe that is not the default behaviour of "left"!

# to lead to i=1 (rounding down).
interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1

# the time bins may not cover the whole set of times in the recording,
# so we need to clip these indices to the valid range
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
n_bins = interpolation_time_bin_edges_s.shape[0] - 1
np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds)

# -- what are the possibilities here anyway?
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1)

# interpolation kernel will be the same per temporal bin
interp_times = np.empty(total_num_chans)
current_start_index = 0
for bin_ind in bins_here:
bin_time = time_bins[bin_ind]
for interp_bin_ind in interpolation_bins_here:
bin_time = interpolation_time_bin_centers_s[interp_bin_ind]
interp_times.fill(bin_time)
channel_motions = motion.get_displacement_at_time_and_depth(
interp_times,
@@ -166,16 +183,17 @@
# ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
# plt.show()

# quick search logic to find frames corresponding to this interpolation bin in the recording
# quickly find the end of this bin, which is also the start of the next
next_start_index = current_start_index + np.searchsorted(
bin_inds[current_start_index:], bin_ind + 1, side="left"
interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left"
)
in_bin = slice(current_start_index, next_start_index)
frames_in_bin = slice(current_start_index, next_start_index)

# here we use a simple np.matmul even if drift_kernel can be super sparse,
# because sparse matmul speed is not so good when multithreading is disabled (due to multiprocessing
# in ChunkRecordingExecutor)
np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin])
np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin])
current_start_index = next_start_index

return traces_corrected
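The per-bin chunking above relies on the bin indices being non-decreasing (frame times are sorted), so each interpolation bin's frames form one contiguous slice. A minimal sketch of that slicing loop on toy data:

import numpy as np

bin_inds = np.array([0, 0, 0, 1, 1, 3, 3, 3])   # sorted bin index per frame (toy data)
current_start = 0
for b in np.arange(bin_inds[0], bin_inds[-1] + 1):
    # first position where the bin index exceeds b, searched only from current_start on
    next_start = current_start + np.searchsorted(bin_inds[current_start:], b + 1, side="left")
    print(b, slice(current_start, next_start))  # bins with no frames yield empty slices
    current_start = next_start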
@@ -297,6 +315,7 @@ def __init__(
p=1,
num_closest=3,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
interpolation_time_bin_size_s=None,
dtype=None,
**spatial_interpolation_kwargs,
@@ -363,9 +382,14 @@ def __init__(

# handle manual interpolation_time_bin_centers_s
# the case where interpolation_time_bin_size_s is set is handled per-segment below
if interpolation_time_bin_centers_s is None:
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
if interpolation_time_bin_size_s is None:
interpolation_time_bin_centers_s = motion.temporal_bins_s
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s
else:
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
)

for segment_index, parent_segment in enumerate(recording._recording_segments):
# finish the per-segment part of the time bin logic
@@ -375,8 +399,13 @@
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
halfbin = interpolation_time_bin_size_s / 2.0
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)
segment_interpolation_time_bin_edges_s = np.arange(
Comment (Collaborator):
(mostly for line 390) Is it possible for interpolation_time_bin_centers_s to be None at this point anymore? If centers and edges are both None, it will be motion.temporal_bins_s; if it is passed, it will not be None; and if centers is None, it will be filled in by ensure_time_bins?

t_start, t_end + halfbin, interpolation_time_bin_size_s
)
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,)
else:
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index]
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index]

rec_segment = InterpolateMotionRecordingSegment(
parent_segment,
Expand All @@ -387,6 +416,7 @@ def __init__(
channel_inds,
segment_index,
segment_interpolation_time_bins_s,
segment_interpolation_time_bin_edges_s,
dtype=dtype_,
)
self.add_recording_segment(rec_segment)
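A quick sketch (with made-up t_start, t_end and bin size) of the per-segment bin construction used in the branch above, showing why the edges array has exactly one more element than the centers array:

import numpy as np

t_start, t_end, interpolation_time_bin_size_s = 0.0, 10.0, 1.0
halfbin = interpolation_time_bin_size_s / 2.0
centers_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)  # 0.5, 1.5, ..., 9.5
edges_s = np.arange(t_start, t_end + halfbin, interpolation_time_bin_size_s)    # 0.0, 1.0, ..., 10.0
assert edges_s.shape == (centers_s.shape[0] + 1,)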
@@ -420,6 +450,7 @@ def __init__(
channel_inds,
segment_index,
interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s,
dtype="float32",
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
@@ -429,13 +460,11 @@ def __init__(
self.channel_inds = channel_inds
self.segment_index = segment_index
self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
self.dtype = dtype
self.motion = motion

def get_traces(self, start_frame, end_frame, channel_indices):
if self.time_vector is not None:
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.")

if start_frame is None:
start_frame = 0
if end_frame is None:
@@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
channel_inds=self.channel_inds,
spatial_interpolation_method=self.spatial_interpolation_method,
spatial_interpolation_kwargs=self.spatial_interpolation_kwargs,
interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s,
)

if channel_indices is not None:
40 changes: 39 additions & 1 deletion src/spikeinterface/sortingcomponents/motion/motion_utils.py
@@ -1,5 +1,5 @@
import warnings
import json
import warnings
from pathlib import Path

import numpy as np
@@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"
self.direction = direction
self.dim = ["x", "y", "z"].index(direction)
self.check_properties()
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s]

def check_properties(self):
assert all(d.ndim == 2 for d in self.displacement)
@@ -576,3 +577,40 @@ def make_3d_motion_histograms(
motion_histograms = np.log2(1 + motion_histograms)

return motion_histograms, temporal_bin_edges, spatial_bin_edges


def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
"""Ensure that both bin edges and bin centers are present
Comment (Collaborator):
A docstring would be useful here, just to explain a) the case where this is used and b) a brief overview of what it is doing.

If I understand correctly, we need both bin centres and bin edges. Given some bin centres we compute the edges, or vice versa: given some bin edges we compute the centres?


If either input is None (but not both), the missing one is reconstructed
from the one provided. Going from edges to centers is done by taking midpoints.
Going from centers to edges is done by taking midpoints and padding with the
left- and rightmost centers.

Parameters
----------
time_bin_centers_s : None or np.array
time_bin_edges_s : None or np.array

Returns
-------
time_bin_centers_s, time_bin_edges_s
"""
if time_bin_centers_s is None and time_bin_edges_s is None:
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

if time_bin_centers_s is None:
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
Comment (Collaborator):
Could this always be float, since we are multiplying by 0.5? If the dtypes need to be the same, should we instead cast time_bin_centers_s to float?

if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

return time_bin_centers_s, time_bin_edges_s


def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
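For context, a hedged usage sketch of the new helpers, assuming they are importable from spikeinterface.sortingcomponents.motion.motion_utils:

import numpy as np
from spikeinterface.sortingcomponents.motion.motion_utils import ensure_time_bins, ensure_time_bin_edges

centers = np.array([0.5, 1.5, 2.5, 3.5])
_, edges = ensure_time_bins(time_bin_centers_s=centers)
print(edges)         # [0.5 1.  2.  3.  3.5] -- interior edges are midpoints, outer edges reuse the first/last centers

centers_back, _ = ensure_time_bins(time_bin_edges_s=edges)
print(centers_back)  # [0.75 1.5  2.5  3.25] -- midpoints of the edges

print(ensure_time_bin_edges(time_bin_centers_s=centers))  # same edges as above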