Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 11, 2024
1 parent 35ab6b0 commit 18b7cfe
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 212 deletions.
14 changes: 6 additions & 8 deletions doc/how_to/drift_with_lfp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ For each patient, the dataset contains two recording : a high pass (AP -
from pathlib import Path
import matplotlib.pyplot as plt
import spikeinterface.full as si
from spikeinterface.sortingcomponents.motion import estimate_motion
Expand All @@ -57,7 +57,7 @@ read the spikeglx file
.. parsed-literal::
SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples
SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples
873.32s (14.56 minutes) - int16 dtype - 1.56 GiB
Expand Down Expand Up @@ -87,24 +87,24 @@ eyes ont the traces plotted with the map mode.
raw_rec,
freq_min=0.5,
freq_max=250,
margin_ms=1500.,
filter_order=3,
dtype="float32",
add_reflect_padding=True,
)
lfprec = si.phase_shift(lfprec)
lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000)
lfprec = si.directional_derivative(lfprec, order=2, edge_order=1)
lfprec = si.average_across_direction(lfprec)
print(lfprec)
.. parsed-literal::
AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples
AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples
873.32s (14.56 minutes) - float32 dtype - 159.91 MiB
Expand Down Expand Up @@ -185,5 +185,3 @@ This motion match the LFP signal above.
.. image:: drift_with_lfp_files/drift_with_lfp_12_1.png


2 changes: 1 addition & 1 deletion doc/how_to/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
process_by_channel_group
load_your_data_into_sorting
benchmark_with_hybrid_recordings
drift_with_lfp
drift_with_lfp
3 changes: 0 additions & 3 deletions examples/how_to/drift_with_lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,3 @@
si.plot_motion(motion, mode='line', ax=ax)
ax.set_xlim(400, 420)
ax.set_ylim(800, 1300)



18 changes: 9 additions & 9 deletions examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from spikeinterface.widgets import plot_motion, plot_motion_info, plot_probe_map

# %%
# First, let's simulate a drifting recording using the
# First, let's simulate a drifting recording using the
# :code:`spikeinterface.generation module`.
#
#
# Here the simulated recording has a small zigzag motion along the 'y' axis of the probe.

static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=200,
duration=300.,
probe_name='Neuropixel-128',
probe_name='Neuropixel-128',
generate_displacement_vector_kwargs=dict(
displacement_sampling_frequency=5.0,
drift_start_um=[0, 20],
Expand All @@ -50,12 +50,12 @@

# %%
# Here we will use the high level function :code:`correct_motion()`
#
#
# Internally, this function is doing all steps of the motion detection:
# 1. **activity profile** : detect peaks and localize them along time and depth
# 2. **motion inference**: estimate the drift motion
# 3. **motion interpolation**: interpolate traces using the estimated motion
#
#
# All steps have an use several methods with many parameters. This is why we can use
# 'preset' which combine methods and related parameters.
#
Expand All @@ -70,30 +70,30 @@

# %%
# The function return a recording 'corrected'
#
#
# A new recording is return, this recording will interpolate motion corrected traces
# when calling get_traces()

print(recording_corrected)

# %%
# Optionally the function also return the `Motion` object itself
#
#

print(motion)

# %%
# This motion can be plotted, in our case the motion has been estimated as non-rigid
# so we can use the use the `mode='map'` to check the motion across depth.
#
#

plot_motion(motion, mode='line')
plot_motion(motion, mode='map')


# %%
# The dict `motion_info` can be used for more plotting.
# Here we can appreciate of the two top axes the raster of peaks depth vs times before and
# Here we can appreciate of the two top axes the raster of peaks depth vs times before and
# after correction.

fig = plt.figure()
Expand Down
14 changes: 6 additions & 8 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.8,
radius_um=80.,
radius_um=80.0,
),
"select_kwargs": dict(),
"localize_peaks_kwargs": dict(
Expand Down Expand Up @@ -76,7 +76,7 @@
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.8,
radius_um=80.,
radius_um=80.0,
),
"select_kwargs": dict(),
"localize_peaks_kwargs": dict(
Expand Down Expand Up @@ -196,7 +196,6 @@
}



def correct_motion(
recording,
preset="nonrigid_accurate",
Expand Down Expand Up @@ -398,16 +397,15 @@ def correct_motion(

if not output_motion and not output_motion_info:
return recording_corrected
out = (recording_corrected, )

out = (recording_corrected,)
if output_motion:
out += (motion, )
out += (motion,)
if output_motion_info:
out += (motion_info, )
out += (motion_info,)
return out



_doc_presets = "\n"
for k, v in motion_options_preset.items():
if k == "":
Expand Down
8 changes: 6 additions & 2 deletions src/spikeinterface/sortingcomponents/motion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .motion_utils import Motion
from .motion_estimation import estimate_motion
from .motion_interpolation import (correct_motion_on_peaks, interpolate_motion_on_traces,
InterpolateMotionRecording, interpolate_motion)
from .motion_interpolation import (
correct_motion_on_peaks,
interpolate_motion_on_traces,
InterpolateMotionRecording,
interpolate_motion,
)
from .motion_cleaner import clean_motion_vector
14 changes: 5 additions & 9 deletions src/spikeinterface/sortingcomponents/motion/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def run(
bin_um=1.0,
hist_margin_um=20.0,
bin_s=1.0,
histogram_depth_smooth_um=1.,
histogram_time_smooth_s=1.,
histogram_depth_smooth_um=1.0,
histogram_time_smooth_s=1.0,
pairwise_displacement_method="conv",
max_displacement_um=100.0,
weight_scale="linear",
Expand All @@ -135,7 +135,6 @@ def run(
lsqr_robust_n_iter=20,
weight_with_amplitude=False,
):


dim = ["x", "y", "z"].index(direction)
contact_depth = recording.get_channel_locations()[:, dim]
Expand All @@ -153,7 +152,7 @@ def run(
win_step_um=win_step_um,
win_scale_um=win_scale_um,
win_margin_um=win_margin_um,
zero_threshold=None
zero_threshold=None,
)

# make 2D histogram raster
Expand Down Expand Up @@ -322,6 +321,7 @@ def compute_pairwise_displacement(
try:
import torch
import torch.nn.functional as F

conv_engine = "torch"
except ImportError:
conv_engine = "numpy"
Expand Down Expand Up @@ -430,7 +430,6 @@ def compute_pairwise_displacement(
return pairwise_displacement, pairwise_displacement_weight



_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust")


Expand Down Expand Up @@ -687,9 +686,6 @@ def jac(p):
return np.squeeze(displacement)





# normxcorr1d is now implemented in dredge
# we keep the old version here but this will be removed soon

Expand Down Expand Up @@ -809,4 +805,4 @@ def jac(p):
# # get rid of NaNs in zero-variance areas
# corr[~npx.isfinite(corr)] = 0

# return corr
# return corr
Loading

0 comments on commit 18b7cfe

Please sign in to comment.