Skip to content

Commit

Permalink
Merge branch 'SpikeInterface:main' into sc2_recording_slices
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Nov 28, 2024
2 parents 047cac4 + 853d8a4 commit d39874f
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 60 deletions.
14 changes: 5 additions & 9 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import warnings
from pathlib import Path

Expand All @@ -7,14 +8,9 @@

from .base import BaseSegment
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import (
convert_bytes_to_str,
convert_seconds_to_str,
)
from .recording_tools import write_binary_recording


from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs
from .recording_tools import write_binary_recording


class BaseRecording(BaseRecordingSnippets):
Expand Down Expand Up @@ -950,11 +946,11 @@ def time_to_sample_index(self, time_s):
sample_index = time_s * self.sampling_frequency
else:
sample_index = (time_s - self.t_start) * self.sampling_frequency
sample_index = round(sample_index)
sample_index = np.round(sample_index).astype(int)
else:
sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1

return int(sample_index)
return sample_index

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal segment
Expand Down
71 changes: 50 additions & 21 deletions src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.preprocessing.filter import fix_dtype

from .motion_utils import ensure_time_bin_edges, ensure_time_bins


def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray:
"""
Expand Down Expand Up @@ -54,14 +56,19 @@ def interpolate_motion_on_traces(
segment_index=None,
channel_inds=None,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
spatial_interpolation_method="kriging",
spatial_interpolation_kwargs={},
dtype=None,
):
"""
Apply inverse motion with spatial interpolation on traces.
Traces can be full traces, but also waveforms snippets.
Traces can be full traces, but also waveforms snippets. Times used for looking up
displacements are controlled by interpolation_time_bin_edges_s or
interpolation_time_bin_centers_s, or fall back to the Motion object's time bins
by default; times in the recording outside these time bins use the closest edge
bin's displacement value during interpolation.
Parameters
----------
Expand All @@ -80,6 +87,9 @@ def interpolate_motion_on_traces(
interpolation_time_bin_centers_s : None or np.array
Manually specify the time bins which the interpolation happens
in for this segment. If None, these are the motion estimate's time bins.
interpolation_time_bin_edges_s : None or np.array
If present, interpolation chunks will be the time bins defined by these edges
rather than interpolation_time_bin_centers_s or the motion's bins.
spatial_interpolation_method : "idw" | "kriging", default: "kriging"
The spatial interpolation method used to interpolate the channel locations:
* idw : Inverse Distance Weighing
Expand Down Expand Up @@ -119,26 +129,33 @@ def interpolate_motion_on_traces(
total_num_chans = channel_locations.shape[0]

# -- determine the blocks of frames that will land in the same interpolation time bin
time_bins = interpolation_time_bin_centers_s
if time_bins is None:
time_bins = motion.temporal_bins_s[segment_index]
bin_s = time_bins[1] - time_bins[0]
bins_start = time_bins[0] - 0.5 * bin_s
# nearest bin center for each frame?
bin_inds = (times - bins_start) // bin_s
bin_inds = bin_inds.astype(int)
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index]
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index]
else:
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
)

# bin the frame times according to the interpolation time bins.
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
# hence the -1. doing it with "left" is not as nice -- we want t==b[0]
# to lead to i=1 (rounding down).
interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1

# the time bins may not cover the whole set of times in the recording,
# so we need to clip these indices to the valid range
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
n_bins = interpolation_time_bin_edges_s.shape[0] - 1
np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds)

# -- what are the possibilities here anyway?
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1)

# inperpolation kernel will be the same per temporal bin
interp_times = np.empty(total_num_chans)
current_start_index = 0
for bin_ind in bins_here:
bin_time = time_bins[bin_ind]
for interp_bin_ind in interpolation_bins_here:
bin_time = interpolation_time_bin_centers_s[interp_bin_ind]
interp_times.fill(bin_time)
channel_motions = motion.get_displacement_at_time_and_depth(
interp_times,
Expand Down Expand Up @@ -166,16 +183,17 @@ def interpolate_motion_on_traces(
# ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
# plt.show()

# quick search logic to find frames corresponding to this interpolation bin in the recording
# quickly find the end of this bin, which is also the start of the next
next_start_index = current_start_index + np.searchsorted(
bin_inds[current_start_index:], bin_ind + 1, side="left"
interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left"
)
in_bin = slice(current_start_index, next_start_index)
frames_in_bin = slice(current_start_index, next_start_index)

# here we use a simple np.matmul even if dirft_kernel can be super sparse.
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
# in ChunkRecordingExecutor)
np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin])
np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin])
current_start_index = next_start_index

return traces_corrected
Expand Down Expand Up @@ -297,6 +315,7 @@ def __init__(
p=1,
num_closest=3,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
interpolation_time_bin_size_s=None,
dtype=None,
**spatial_interpolation_kwargs,
Expand Down Expand Up @@ -363,9 +382,14 @@ def __init__(

# handle manual interpolation_time_bin_centers_s
# the case where interpolation_time_bin_size_s is set is handled per-segment below
if interpolation_time_bin_centers_s is None:
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
if interpolation_time_bin_size_s is None:
interpolation_time_bin_centers_s = motion.temporal_bins_s
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s
else:
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
)

for segment_index, parent_segment in enumerate(recording._recording_segments):
# finish the per-segment part of the time bin logic
Expand All @@ -375,8 +399,13 @@ def __init__(
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
halfbin = interpolation_time_bin_size_s / 2.0
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)
segment_interpolation_time_bin_edges_s = np.arange(
t_start, t_end + halfbin, interpolation_time_bin_size_s
)
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,)
else:
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index]
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index]

rec_segment = InterpolateMotionRecordingSegment(
parent_segment,
Expand All @@ -387,6 +416,7 @@ def __init__(
channel_inds,
segment_index,
segment_interpolation_time_bins_s,
segment_interpolation_time_bin_edges_s,
dtype=dtype_,
)
self.add_recording_segment(rec_segment)
Expand Down Expand Up @@ -420,6 +450,7 @@ def __init__(
channel_inds,
segment_index,
interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s,
dtype="float32",
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
Expand All @@ -429,13 +460,11 @@ def __init__(
self.channel_inds = channel_inds
self.segment_index = segment_index
self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
self.dtype = dtype
self.motion = motion

def get_traces(self, start_frame, end_frame, channel_indices):
if self.time_vector is not None:
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.")

if start_frame is None:
start_frame = 0
if end_frame is None:
Expand All @@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
channel_inds=self.channel_inds,
spatial_interpolation_method=self.spatial_interpolation_method,
spatial_interpolation_kwargs=self.spatial_interpolation_kwargs,
interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s,
)

if channel_indices is not None:
Expand Down
40 changes: 39 additions & 1 deletion src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
import json
import warnings
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"
self.direction = direction
self.dim = ["x", "y", "z"].index(direction)
self.check_properties()
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s]

def check_properties(self):
assert all(d.ndim == 2 for d in self.displacement)
Expand Down Expand Up @@ -576,3 +577,40 @@ def make_3d_motion_histograms(
motion_histograms = np.log2(1 + motion_histograms)

return motion_histograms, temporal_bin_edges, spatial_bin_edges


def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
"""Ensure that both bin edges and bin centers are present
If either of the inputs are None but not both, the missing is reconstructed
from the present. Going from edges to centers is done by taking midpoints.
Going from centers to edges is done by taking midpoints and padding with the
left and rightmost centers.
Parameters
----------
time_bin_centers_s : None or np.array
time_bin_edges_s : None or np.array
Returns
-------
time_bin_centers_s, time_bin_edges_s
"""
if time_bin_centers_s is None and time_bin_edges_s is None:
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

if time_bin_centers_s is None:
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

return time_bin_centers_s, time_bin_edges_s


def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
Loading

0 comments on commit d39874f

Please sign in to comment.