diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 8542b62524..c97c8324ba 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,11 +4,8 @@ import spikeinterface.core as sc from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( - InterpolateMotionRecording, - correct_motion_on_peaks, - interpolate_motion, - interpolate_motion_on_traces, -) + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + interpolate_motion_on_traces) from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -84,26 +81,26 @@ def test_interpolate_motion_on_traces(): def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. # there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest @@ -131,8 +128,8 @@ def test_cross_band_interpolation(): fs_ap = 300.0 t_start = 10.0 total_duration = 5.0 - nt_lfp = int(fs_lfp * total_duration) - nt_ap = int(fs_ap * total_duration) + num_samples_lfp = int(fs_lfp * total_duration) + num_samples_ap = int(fs_ap * total_duration) t_switch = 3 # because interpolation uses bin centers logic, there will be a half @@ -140,18 +137,18 @@ def test_cross_band_interpolation(): halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) # channel geometry - nc = 10 - geom = np.c_[np.zeros(nc), np.arange(nc)] + num_chans = 10 + geom = np.c_[np.zeros(num_chans), np.arange(num_chans)] # make an LFP recording which drifts a bit - traces_lfp = np.zeros((nt_lfp, nc)) + traces_lfp = np.zeros((num_samples_lfp, num_chans)) traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) rec_lfp.set_dummy_probe_from_locations(geom) # same for AP - traces_ap = np.zeros((nt_ap, nc)) + traces_ap = np.zeros((num_samples_ap, num_chans)) traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) @@ -160,8 +157,8 @@ def test_cross_band_interpolation(): # set times for both, and silence the warning with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) - rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp) - rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap) + rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap) # estimate motion motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) @@ -169,7 +166,7 @@ def test_cross_band_interpolation(): # nearest to keep it simple rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) traces_corrected = rec_corrected.get_traces() - target = np.zeros((nt_ap, nc - 2)) + target = np.zeros((num_samples_ap, num_chans - 2)) target[:, 4] = 1 ii, jj = np.nonzero(traces_corrected) assert np.array_equal(traces_corrected, target)