Skip to content

Commit

Permalink
Merge branch 'spike_location_with_true_channel' of github.com:samuelg…
Browse files Browse the repository at this point in the history
…arcia/spikeinterface into spike_location_with_true_channel
  • Loading branch information
samuelgarcia committed Oct 25, 2023
2 parents e238608 + f94c594 commit 57ba043
Show file tree
Hide file tree
Showing 26 changed files with 706 additions and 1,628 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.0
hooks:
- id: black
files: ^src/
4 changes: 2 additions & 2 deletions examples/how_to/get_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@

print('Units in agreement between TDC, SC2, and KS2:', sorting_agreement.get_unit_ids())

w_multi = sw.plot_multicomp_agreement(comp_multi)
w_multi = sw.plot_multicomp_agreement_by_sorter(comp_multi)
w_multi = sw.plot_multicomparison_agreement(comp_multi)
w_multi = sw.plot_multicomparison_agreement_by_sorter(comp_multi)
# -

# We see that 10 unit were found by all sorters (note that this simulated dataset is a very simple example, and usually sorters do not do such a great job)!
Expand Down
10 changes: 5 additions & 5 deletions examples/modules_gallery/widgets/plot_2_sort_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

##############################################################################
# plot_rasters()
# ~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~

w_rs = sw.plot_rasters(sorting)

##############################################################################
# plot_isi_distribution()
# ~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~

w_isi = sw.plot_isi_distribution(sorting, window_ms=150.0, bin_ms=5.0)

Expand All @@ -43,10 +43,10 @@


##############################################################################
# plot_presence()
# ~~~~~~~~~~~~~~~~~~~~~~~~
# plot_unit_presence()
# ~~~~~~~~~~~~~~~~~~~~


w_pr = sw.plot_presence(sorting)
w_pr = sw.plot_unit_presence(sorting)

plt.show()
6 changes: 3 additions & 3 deletions examples/modules_gallery/widgets/plot_4_peaks_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@

##############################################################################
# This "peaks" vector can be used in several widgets, for instance
# plot_peak_activity_map()
# plot_peak_activity()

si.plot_peak_activity_map(rec_filtred, peaks=peaks)
si.plot_peak_activity(rec_filtred, peaks=peaks)

##############################################################################
# can be also animated with bin_duration_s=1.

si.plot_peak_activity_map(rec_filtred, bin_duration_s=1.)
si.plot_peak_activity(rec_filtred, bin_duration_s=1.)


plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,6 @@ def test_compare_multiple_sorters():

msc = MultiSortingComparison.load_from_folder(multicomparison_folder)

# import spikeinterface.widgets as sw
# import matplotlib.pyplot as plt
# sw.plot_multicomp_graph(msc)
# sw.plot_multicomp_agreement(msc)
# sw.plot_multicomp_agreement_by_sorter(msc)
# plt.show()


def test_compare_multi_segment():
num_segments = 3
Expand Down
89 changes: 74 additions & 15 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .numpyextractors import NumpyRecording, NumpySorting
from .basesorting import minimum_spike_dtype

from probeinterface import Probe, generate_linear_probe
from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe

from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting
from .snippets_tools import snippets_from_sorting
Expand Down Expand Up @@ -93,7 +93,6 @@ def generate_recording(
probe = probe.to_3d()
probe.set_device_channel_indices(np.arange(num_channels))
recording.set_probe(probe, in_place=True)
probe = generate_linear_probe(num_elec=num_channels)

return recording

Expand Down Expand Up @@ -122,7 +121,7 @@ def generate_sorting(
durations=[10.325, 3.5], #  in s for 2 segments
firing_rates=3.0,
empty_units=None,
refractory_period_ms=3.0, # in ms
refractory_period_ms=4.0, # in ms
add_spikes_on_borders=False,
num_spikes_per_border=3,
border_size_samples=20,
Expand All @@ -143,7 +142,7 @@ def generate_sorting(
The firing rate of each unit (in Hz).
empty_units : list, default: None
List of units that will have no spikes. (used for testing mainly).
refractory_period_ms : float, default: 3.0
refractory_period_ms : float, default: 4.0
The refractory period in ms
add_spikes_on_borders : bool, default: False
If True, spikes will be added close to the borders of the segments.
Expand Down Expand Up @@ -881,9 +880,10 @@ def generate_single_fake_waveform(
depolarization_ms=(0.09, 0.14),
repolarization_ms=(0.5, 0.8),
recovery_ms=(1.0, 1.5),
positive_amplitude=(0.05, 0.15),
positive_amplitude=(0.1, 0.25),
smooth_ms=(0.03, 0.07),
decay_power=(1.4, 1.8),
propagation_speed=(250.0, 350.0), # um / ms
)


Expand Down Expand Up @@ -931,13 +931,14 @@ def generate_templates(
An optional dict containing parameters per units.
Keys are parameter names:
* 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000))
* 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000))
* 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14))
* 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8))
* 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5))
* 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1)
* 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07))
* 'decay_power': the decay power (default range: (1.2-1.8))
* 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)).
Values contains vector with same size of num_units.
If the key is not in dict then it is generated using unit_params_range
unit_params_range: dict of tuple
Expand Down Expand Up @@ -985,12 +986,16 @@ def generate_templates(
assert unit_params[k].size == num_units
params[k] = unit_params[k]
else:
v = rng.random(num_units)
if k in unit_params_range:
lim0, lim1 = unit_params_range[k]
lims = unit_params_range[k]
else:
lims = default_unit_params_range[k]
if lims is not None:
lim0, lim1 = lims
v = rng.random(num_units)
params[k] = v * (lim1 - lim0) + lim0
else:
lim0, lim1 = default_unit_params_range[k]
params[k] = v * (lim1 - lim0) + lim0
params[k] = [None] * num_units

for u in range(num_units):
wf = generate_single_fake_waveform(
Expand All @@ -1006,17 +1011,42 @@ def generate_templates(
dtype=dtype,
)

## Add a spatial decay depend on distance from unit to each channel
alpha = params["alpha"][u]
# the espilon avoid enormous factors
eps = 1.0
# naive formula for spatial decay
pow = params["decay_power"][u]
channel_factors = alpha / (distances[u, :] + eps) ** pow
wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :]

# This mimic a propagation delay for distant channel
propagation_speed = params["propagation_speed"][u]
if propagation_speed is not None:
# the speed is um/ms
dist = distances[u, :].copy()
dist -= np.min(dist)
delay_s = dist / propagation_speed / 1000.0
sample_shifts = delay_s * fs

# apply the delay with fft transform to get sub sample shift
n = wfs.shape[0]
wfs_f = np.fft.rfft(wfs, axis=0)
if n % 2 == 0:
# n is even sig_f[-1] is nyquist and so pi
omega = np.linspace(0, np.pi, wfs_f.shape[0])
else:
# n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!!
omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0])
# broadcast omega and sample_shifts depend the axis
shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :]
wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0)

if upsample_factor is not None:
for f in range(upsample_factor):
templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :]
templates[u, :, :, f] = wfs[f::upsample_factor]
else:
templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :]
templates[u, :, :] = wfs

return templates

Expand Down Expand Up @@ -1322,12 +1352,19 @@ def generate_ground_truth_recording(
num_units=10,
sorting=None,
probe=None,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
templates=None,
ms_before=1.0,
ms_after=3.0,
upsample_factor=None,
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5),
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
generate_templates_kwargs=dict(),
Expand All @@ -1350,7 +1387,9 @@ def generate_ground_truth_recording(
sorting: Sorting or None
An external sorting object. If not provide, one is genrated.
probe: Probe or None
An external Probe object. If not provided of linear probe is generated.
An external Probe object. If not provided a probe is generated using generate_probe_kwargs.
generate_probe_kwargs: dict
A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`.
templates: np.array or None
The templates of units.
If None they are generated.
Expand Down Expand Up @@ -1407,8 +1446,28 @@ def generate_ground_truth_recording(
num_spikes = sorting.to_spike_vector().size

if probe is None:
probe = generate_linear_probe(num_elec=num_channels)
# probe = generate_linear_probe(num_elec=num_channels)
# probe.set_device_channel_indices(np.arange(num_channels))

prb_kwargs = generate_probe_kwargs.copy()
if "num_contact_per_column" in prb_kwargs:
assert (
prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"]
) == num_channels, (
"generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns"
)
elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs:
n = num_channels // prb_kwargs["num_columns"]
num_contact_per_column = [n] * prb_kwargs["num_columns"]
mid = prb_kwargs["num_columns"] // 2
num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"]
prb_kwargs["num_contact_per_column"] = num_contact_per_column
else:
raise ValueError("num_columns should be provided in dict generate_probe_kwargs")

probe = generate_multi_columns_probe(**prb_kwargs)
probe.set_device_channel_indices(np.arange(num_channels))

else:
num_channels = probe.get_contact_count()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
from spikeinterface.qualitymetrics import compute_quality_metrics
from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth
from spikeinterface.widgets import (
plot_sorting_performance,
plot_agreement_matrix,
plot_comparison_collision_by_similarity,
plot_unit_templates,
plot_unit_waveforms,
plot_gt_performances,
)

import time
Expand Down Expand Up @@ -474,13 +471,12 @@ def plot(self, comp, title=None):
ax = axs[1, 0]
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plot_sorting_performance(comp, self.metrics, performance_name="accuracy", metric_name="snr", ax=ax, color="r")
plot_sorting_performance(comp, self.metrics, performance_name="recall", metric_name="snr", ax=ax, color="g")
plot_sorting_performance(comp, self.metrics, performance_name="precision", metric_name="snr", ax=ax, color="b")
ax.legend(["accuracy", "recall", "precision"])

ax = axs[1, 1]
plot_gt_performances(comp, ax=ax)
for k in ("accuracy", "recall", "precision"):
x = comp.get_performance()[k]
y = self.metrics["snr"]
ax.scatter(x, y, markersize=10, marker=".", label=k)
ax.legend()

ax = axs[0, 1]
if self.exhaustive_gt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from spikeinterface.extractors import read_mearec
from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten
from spikeinterface.sorters import run_sorter, read_sorter_folder
from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances

from spikeinterface.comparison import GroundTruthComparison
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
Expand Down
6 changes: 0 additions & 6 deletions src/spikeinterface/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,3 @@
# general functions
from .utils import get_some_colors, get_unit_colors, array_to_image
from .base import set_default_plotter_backend, get_default_plotter_backend


# we keep this to keep compatibility so we have all previous widgets
# except the one that have been ported that are imported
# with "from .widget_list import *" in the first line
from ._legacy_mpl_widgets import *
48 changes: 0 additions & 48 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py

This file was deleted.

Loading

0 comments on commit 57ba043

Please sign in to comment.