Skip to content

Commit

Permalink
Add tutorial motion.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jul 8, 2024
1 parent 840e9c1 commit 8ef01d9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 4 deletions.
103 changes: 103 additions & 0 deletions examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Motion estimation
=================
SpikeInterface offers a very flexible framework to handle drift as a
preprocessing step. If you want to know more, please read the
:ref:`motion_correction` section of the documentation.
Here a short example with a simulated drifting recording.
"""

# %%
import matplotlib.pyplot as plt


from spikeinterface.generation import generate_drifting_recording
from spikeinterface.preprocessing import correct_motion
from spikeinterface.widgets import plot_motion, plot_motion_info, plot_probe_map

# %%
# First, let's simulate a drifting recording using the
# :code:`spikeinterface.generation module`.
#
# Here the simulated recording has a small zigzag motion along the 'y' axis of the probe.

static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=200,
duration=300.,
probe_name='Neuropixel-128',
generate_displacement_vector_kwargs=dict(
displacement_sampling_frequency=5.0,
drift_start_um=[0, 20],
drift_stop_um=[0, -20],
drift_step_um=1,
motion_list=[
dict(
drift_mode="zigzag",
non_rigid_gradient=None,
t_start_drift=60.0,
t_end_drift=None,
period_s=200,
),
],
),
seed=2205,
)

plot_probe_map(drifting_recording)

# %%
# Here we will use the high level function :code:`correct_motion()`
#
# Internally, this function is doing all steps of the motion detection:
# 1. **activity profile** : detect peaks and localize them along time and depth
# 2. **motion inference**: estimate the drift motion
# 3. **motion interpolation**: interpolate traces using the estimated motion
#
# All steps have an use several methods with many parameters. This is why we can use
# 'preset' which combine methods and related parameters.
#
# This function can take a while peak detection and localization is a slow process
# that need to go trought the entire traces

recording_corrected, motion, motion_info = correct_motion(
drifting_recording, preset="nonrigid_fast_and_accurate",
output_motion=True, output_motion_info=True,
n_jobs=-1, progress_bar=True,
)

# %%
# The function return a recording 'corrected'
#
# A new recording is return, this recording will interpolate motion corrected traces
# when calling get_traces()

print(recording_corrected)

# %%
# Optionally the function also return the `Motion` object itself
#

print(motion)

# %%
# This motion can be plotted, in our case the motion has been estimated as non-rigid
# so we can use the use the `mode='map'` to check the motion across depth.
#

plot_motion(motion, mode='line')
plot_motion(motion, mode='map')


# %%
# The dict `motion_info` can be used for more plotting.
# Here we can appreciate of the two top axes the raster of peaks depth vs times before and
# after correction.

fig = plt.figure()
plot_motion_info(motion_info, drifting_recording, amplitude_cmap="inferno", color_amplitude=True, figure=fig)
fig.axes[0].set_ylim(520, 620)
plt.show()
# %%
20 changes: 16 additions & 4 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def correct_motion(
recording,
preset="nonrigid_accurate",
folder=None,
output_motion=False,
output_motion_info=False,
overwrite=False,
detect_kwargs={},
Expand Down Expand Up @@ -251,6 +252,8 @@ def correct_motion(
The preset name
folder : Path str or None, default: None
If not None then intermediate motion info are saved into a folder
output_motion : bool, default: False
It True, the function returns a `motion` object.
output_motion_info : bool, default: False
If True, then the function returns a `motion_info` dictionary that contains variables
to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement)
Expand All @@ -275,8 +278,11 @@ def correct_motion(
-------
recording_corrected : Recording
The motion corrected recording
motion : Motion
Optional output if `output_motion=True`.
motion_info : dict
Optional output if `output_motion_info=True`. The key "motion" holds the Motion object.
Optional output if `output_motion_info=True`. This dict contains several variable for
for plotting. See `plot_motion_info()`
"""
# local import are important because "sortingcomponents" is not important by default
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods
Expand Down Expand Up @@ -390,10 +396,16 @@ def correct_motion(
if folder is not None:
save_motion_info(motion_info, folder, overwrite=overwrite)

if output_motion_info:
return recording_corrected, motion_info
else:
if not output_motion and not output_motion_info:
return recording_corrected

out = (recording_corrected, )
if output_motion:
out += (motion, )
if output_motion_info:
out += (motion_info, )
return out



_doc_presets = "\n"
Expand Down

0 comments on commit 8ef01d9

Please sign in to comment.