diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index a48e10b3e1..203bd2473b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -457,7 +457,7 @@ def make_2d_motion_histogram( """ n_samples = recording.get_num_samples() mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) + maxt_s = recording.sample_index_to_time(n_samples - 1) temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) @@ -542,7 +542,7 @@ def make_3d_motion_histograms( """ n_samples = recording.get_num_samples() mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) + maxt_s = recording.sample_index_to_time(n_samples - 1) temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py index 3c83a56b9d..0726ca5a87 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py @@ -45,7 +45,10 @@ def setup_dataset_and_peaks(cache_folder): peak_location_path = cache_folder / "dataset_peak_locations.npy" np.save(peak_location_path, peak_locations) - return recording, sorting, cache_folder + recording_with_times = recording.clone() + recording_with_times.set_times(recording.get_times() + 100) + + return recording, recording_with_times, sorting, cache_folder @pytest.fixture(scope="module", name="dataset") @@ -56,7 +59,7 @@ def dataset_fixture(create_cache_folder): def test_estimate_motion(dataset): # recording, sorting = make_dataset() - recording, sorting, cache_folder = dataset + recording, recording_with_times, sorting, cache_folder = dataset peaks = np.load(cache_folder / "dataset_peaks.npy") peak_locations = np.load(cache_folder / "dataset_peak_locations.npy") @@ -146,78 +149,83 @@ def test_estimate_motion(dataset): ), } - motions = {} - for name, cases_kwargs in all_cases.items(): - print(name) - - kwargs = dict( - direction="y", - bin_s=1.0, - bin_um=10.0, - margin_um=5, - extra_outputs=True, - ) - kwargs.update(cases_kwargs) - - motion, extra = estimate_motion(recording, peaks, peak_locations, **kwargs) - motions[name] = motion - - if cases_kwargs["rigid"]: - assert motion.displacement[0].shape[1] == 1 - else: - assert motion.displacement[0].shape[1] > 1 - - # # Test saving to disk - # corrected_rec = InterpolateMotionRecording( - # recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" - # ) - # rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") - # if rec_folder.exists(): - # shutil.rmtree(rec_folder) - # corrected_rec.save(folder=rec_folder) - - if DEBUG: - fig, ax = plt.subplots() - seg_index = 0 - ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index]) - - # motion_histogram = extra_check['motion_histogram'] - # spatial_hist_bins = extra_check['spatial_hist_bin_edges'] - # fig, ax = plt.subplots() - # extent = (temporal_bins[0], temporal_bins[-1], spatial_hist_bins[0], spatial_hist_bins[-1]) - # im = ax.imshow(motion_histogram.T, interpolation='nearest', - # origin='lower', aspect='auto', extent=extent) - - # fig, ax = plt.subplots() - # pairwise_displacement = extra_check['pairwise_displacement_list'][0] - # im = ax.imshow(pairwise_displacement, interpolation='nearest', - # cmap='PiYG', origin='lower', aspect='auto', extent=None) - # im.set_clim(-40, 40) - # ax.set_aspect('equal') - # fig.colorbar(im) - - plt.show() - - # same params with differents engine should be the same - motion0 = motions["rigid / decentralized / torch"] - motion1 = motions["rigid / decentralized / numpy"] - assert motion0 == motion1 - - motion0 = motions["rigid / decentralized / torch / time_horizon_s"] - motion1 = motions["rigid / decentralized / numpy / time_horizon_s"] - np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - - motion0 = motions["non-rigid / decentralized / torch"] - motion1 = motions["non-rigid / decentralized / numpy"] - np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - - motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] - motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"] - np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - - motion0 = motions["non-rigid / decentralized / torch / spatial_prior"] - motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"] - np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) + for rec in [recording, recording_with_times]: + motions = {} + for name, cases_kwargs in all_cases.items(): + print(name) + + kwargs = dict( + direction="y", + bin_s=1.0, + bin_um=10.0, + margin_um=5, + extra_outputs=True, + ) + kwargs.update(cases_kwargs) + + motion, extra = estimate_motion(rec, peaks, peak_locations, **kwargs) + motions[name] = motion + + if cases_kwargs["rigid"]: + assert motion.displacement[0].shape[1] == 1 + else: + assert motion.displacement[0].shape[1] > 1 + + if rec.has_time_vector(): + assert np.all(motion.temporal_bins_s[0] >= rec.get_times()[0]) + assert np.all(motion.temporal_bins_s[0] <= rec.get_times()[-1]) + + # # Test saving to disk + # corrected_rec = InterpolateMotionRecording( + # recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" + # ) + # rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") + # if rec_folder.exists(): + # shutil.rmtree(rec_folder) + # corrected_rec.save(folder=rec_folder) + + if DEBUG: + fig, ax = plt.subplots() + seg_index = 0 + ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index]) + + # motion_histogram = extra_check['motion_histogram'] + # spatial_hist_bins = extra_check['spatial_hist_bin_edges'] + # fig, ax = plt.subplots() + # extent = (temporal_bins[0], temporal_bins[-1], spatial_hist_bins[0], spatial_hist_bins[-1]) + # im = ax.imshow(motion_histogram.T, interpolation='nearest', + # origin='lower', aspect='auto', extent=extent) + + # fig, ax = plt.subplots() + # pairwise_displacement = extra_check['pairwise_displacement_list'][0] + # im = ax.imshow(pairwise_displacement, interpolation='nearest', + # cmap='PiYG', origin='lower', aspect='auto', extent=None) + # im.set_clim(-40, 40) + # ax.set_aspect('equal') + # fig.colorbar(im) + + plt.show() + + # same params with differents engine should be the same + motion0 = motions["rigid / decentralized / torch"] + motion1 = motions["rigid / decentralized / numpy"] + assert motion0 == motion1 + + motion0 = motions["rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) + + motion0 = motions["non-rigid / decentralized / torch"] + motion1 = motions["non-rigid / decentralized / numpy"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) + + motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) + + motion0 = motions["non-rigid / decentralized / torch / spatial_prior"] + motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) if __name__ == "__main__":