Skip to content

Commit

Permalink
Merge pull request #3218 from alejoe91/fix-estmiate-motion-with-times
Browse files Browse the repository at this point in the history
Fix estimate_motion when time_vector is set
  • Loading branch information
samuelgarcia authored Jul 18, 2024
2 parents c929e51 + e368ea5 commit 5e222d3
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 76 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def make_2d_motion_histogram(
"""
n_samples = recording.get_num_samples()
mint_s = recording.sample_index_to_time(0)
maxt_s = recording.sample_index_to_time(n_samples)
maxt_s = recording.sample_index_to_time(n_samples - 1)
temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s)
if spatial_bin_edges is None:
spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um)
Expand Down Expand Up @@ -542,7 +542,7 @@ def make_3d_motion_histograms(
"""
n_samples = recording.get_num_samples()
mint_s = recording.sample_index_to_time(0)
maxt_s = recording.sample_index_to_time(n_samples)
maxt_s = recording.sample_index_to_time(n_samples - 1)
temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s)
if spatial_bin_edges is None:
spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def setup_dataset_and_peaks(cache_folder):
peak_location_path = cache_folder / "dataset_peak_locations.npy"
np.save(peak_location_path, peak_locations)

return recording, sorting, cache_folder
recording_with_times = recording.clone()
recording_with_times.set_times(recording.get_times() + 100)

return recording, recording_with_times, sorting, cache_folder


@pytest.fixture(scope="module", name="dataset")
Expand All @@ -56,7 +59,7 @@ def dataset_fixture(create_cache_folder):

def test_estimate_motion(dataset):
# recording, sorting = make_dataset()
recording, sorting, cache_folder = dataset
recording, recording_with_times, sorting, cache_folder = dataset

peaks = np.load(cache_folder / "dataset_peaks.npy")
peak_locations = np.load(cache_folder / "dataset_peak_locations.npy")
Expand Down Expand Up @@ -146,78 +149,83 @@ def test_estimate_motion(dataset):
),
}

motions = {}
for name, cases_kwargs in all_cases.items():
print(name)

kwargs = dict(
direction="y",
bin_s=1.0,
bin_um=10.0,
margin_um=5,
extra_outputs=True,
)
kwargs.update(cases_kwargs)

motion, extra = estimate_motion(recording, peaks, peak_locations, **kwargs)
motions[name] = motion

if cases_kwargs["rigid"]:
assert motion.displacement[0].shape[1] == 1
else:
assert motion.displacement[0].shape[1] > 1

# # Test saving to disk
# corrected_rec = InterpolateMotionRecording(
# recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate"
# )
# rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording")
# if rec_folder.exists():
# shutil.rmtree(rec_folder)
# corrected_rec.save(folder=rec_folder)

if DEBUG:
fig, ax = plt.subplots()
seg_index = 0
ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index])

# motion_histogram = extra_check['motion_histogram']
# spatial_hist_bins = extra_check['spatial_hist_bin_edges']
# fig, ax = plt.subplots()
# extent = (temporal_bins[0], temporal_bins[-1], spatial_hist_bins[0], spatial_hist_bins[-1])
# im = ax.imshow(motion_histogram.T, interpolation='nearest',
# origin='lower', aspect='auto', extent=extent)

# fig, ax = plt.subplots()
# pairwise_displacement = extra_check['pairwise_displacement_list'][0]
# im = ax.imshow(pairwise_displacement, interpolation='nearest',
# cmap='PiYG', origin='lower', aspect='auto', extent=None)
# im.set_clim(-40, 40)
# ax.set_aspect('equal')
# fig.colorbar(im)

plt.show()

# same params with differents engine should be the same
motion0 = motions["rigid / decentralized / torch"]
motion1 = motions["rigid / decentralized / numpy"]
assert motion0 == motion1

motion0 = motions["rigid / decentralized / torch / time_horizon_s"]
motion1 = motions["rigid / decentralized / numpy / time_horizon_s"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch"]
motion1 = motions["non-rigid / decentralized / numpy"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"]
motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch / spatial_prior"]
motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)
for rec in [recording, recording_with_times]:
motions = {}
for name, cases_kwargs in all_cases.items():
print(name)

kwargs = dict(
direction="y",
bin_s=1.0,
bin_um=10.0,
margin_um=5,
extra_outputs=True,
)
kwargs.update(cases_kwargs)

motion, extra = estimate_motion(rec, peaks, peak_locations, **kwargs)
motions[name] = motion

if cases_kwargs["rigid"]:
assert motion.displacement[0].shape[1] == 1
else:
assert motion.displacement[0].shape[1] > 1

if rec.has_time_vector():
assert np.all(motion.temporal_bins_s[0] >= rec.get_times()[0])
assert np.all(motion.temporal_bins_s[0] <= rec.get_times()[-1])

# # Test saving to disk
# corrected_rec = InterpolateMotionRecording(
# recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate"
# )
# rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording")
# if rec_folder.exists():
# shutil.rmtree(rec_folder)
# corrected_rec.save(folder=rec_folder)

if DEBUG:
fig, ax = plt.subplots()
seg_index = 0
ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index])

# motion_histogram = extra_check['motion_histogram']
# spatial_hist_bins = extra_check['spatial_hist_bin_edges']
# fig, ax = plt.subplots()
# extent = (temporal_bins[0], temporal_bins[-1], spatial_hist_bins[0], spatial_hist_bins[-1])
# im = ax.imshow(motion_histogram.T, interpolation='nearest',
# origin='lower', aspect='auto', extent=extent)

# fig, ax = plt.subplots()
# pairwise_displacement = extra_check['pairwise_displacement_list'][0]
# im = ax.imshow(pairwise_displacement, interpolation='nearest',
# cmap='PiYG', origin='lower', aspect='auto', extent=None)
# im.set_clim(-40, 40)
# ax.set_aspect('equal')
# fig.colorbar(im)

plt.show()

# same params with differents engine should be the same
motion0 = motions["rigid / decentralized / torch"]
motion1 = motions["rigid / decentralized / numpy"]
assert motion0 == motion1

motion0 = motions["rigid / decentralized / torch / time_horizon_s"]
motion1 = motions["rigid / decentralized / numpy / time_horizon_s"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch"]
motion1 = motions["non-rigid / decentralized / numpy"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"]
motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)

motion0 = motions["non-rigid / decentralized / torch / spatial_prior"]
motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"]
np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)


if __name__ == "__main__":
Expand Down

0 comments on commit 5e222d3

Please sign in to comment.