Skip to content

Commit

Permalink
Variable names in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 21, 2024
1 parent c890603 commit 6d2e479
Showing 1 changed file with 21 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import spikeinterface.core as sc
from spikeinterface.sortingcomponents.motion import Motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import (
InterpolateMotionRecording,
correct_motion_on_peaks,
interpolate_motion,
interpolate_motion_on_traces,
)
InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion,
interpolate_motion_on_traces)
from spikeinterface.sortingcomponents.tests.common import make_dataset


Expand Down Expand Up @@ -84,26 +81,26 @@ def test_interpolate_motion_on_traces():
def test_interpolation_simple():
# a recording where a 1 moves at 1 chan per second. 30 chans 10 frames.
# there will be 9 chans of drift, so we add 9 chans of padding to the bottom
nt = nc0 = 10 # these need to be the same for this test
nc1 = nc0 + nc0 - 1
traces = np.zeros((nt, nc1), dtype="float32")
traces[:, :nc0] = np.eye(nc0)
n_samples = num_chans_orig = 10 # these need to be the same for this test
num_chans_drifted = num_chans_orig + num_chans_orig - 1
traces = np.zeros((n_samples, num_chans_drifted), dtype="float32")
traces[:, :num_chans_orig] = np.eye(num_chans_orig)
rec = sc.NumpyRecording(traces, sampling_frequency=1)
rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)])
rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)])

true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1))
true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1))
rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest")
traces_corrected = rec_corrected.get_traces()
assert traces_corrected.shape == (nc0, nc0)
assert np.array_equal(traces_corrected[:, 0], np.ones(nt))
assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1)))
assert traces_corrected.shape == (num_chans_orig, num_chans_orig)
assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples))
assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1)))

# let's try a new version where we interpolate too slowly
rec_corrected = interpolate_motion(
rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2
)
traces_corrected = rec_corrected.get_traces()
assert traces_corrected.shape == (nc0, nc0)
assert traces_corrected.shape == (num_chans_orig, num_chans_orig)
# what happens with nearest here?
# well... due to rounding towards the nearest even number, the motion (which at
# these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest
Expand Down Expand Up @@ -131,27 +128,27 @@ def test_cross_band_interpolation():
fs_ap = 300.0
t_start = 10.0
total_duration = 5.0
nt_lfp = int(fs_lfp * total_duration)
nt_ap = int(fs_ap * total_duration)
num_samples_lfp = int(fs_lfp * total_duration)
num_samples_ap = int(fs_ap * total_duration)
t_switch = 3

# because interpolation uses bin centers logic, there will be a half
# bin offset at the change point in the AP recording.
halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp))

# channel geometry
nc = 10
geom = np.c_[np.zeros(nc), np.arange(nc)]
num_chans = 10
geom = np.c_[np.zeros(num_chans), np.arange(num_chans)]

# make an LFP recording which drifts a bit
traces_lfp = np.zeros((nt_lfp, nc))
traces_lfp = np.zeros((num_samples_lfp, num_chans))
traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0
traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0
rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp)
rec_lfp.set_dummy_probe_from_locations(geom)

# same for AP
traces_ap = np.zeros((nt_ap, nc))
traces_ap = np.zeros((num_samples_ap, num_chans))
traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0
traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0
rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap)
Expand All @@ -160,16 +157,16 @@ def test_cross_band_interpolation():
# set times for both, and silence the warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp)
rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap)
rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp)
rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap)

# estimate motion
motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True)

# nearest to keep it simple
rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2)
traces_corrected = rec_corrected.get_traces()
target = np.zeros((nt_ap, nc - 2))
target = np.zeros((num_samples_ap, num_chans - 2))
target[:, 4] = 1
ii, jj = np.nonzero(traces_corrected)
assert np.array_equal(traces_corrected, target)
Expand Down

0 comments on commit 6d2e479

Please sign in to comment.