Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve generate.py with spatial on generate_template() #2098

Merged
merged 12 commits into from
Oct 24, 2023
89 changes: 74 additions & 15 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .numpyextractors import NumpyRecording, NumpySorting
from .basesorting import minimum_spike_dtype

from probeinterface import Probe, generate_linear_probe
from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe

from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting
from .snippets_tools import snippets_from_sorting
Expand Down Expand Up @@ -93,7 +93,6 @@ def generate_recording(
probe = probe.to_3d()
probe.set_device_channel_indices(np.arange(num_channels))
recording.set_probe(probe, in_place=True)
probe = generate_linear_probe(num_elec=num_channels)

return recording

Expand Down Expand Up @@ -122,7 +121,7 @@ def generate_sorting(
durations=[10.325, 3.5], #  in s for 2 segments
firing_rates=3.0,
empty_units=None,
refractory_period_ms=3.0, # in ms
refractory_period_ms=4.0, # in ms
add_spikes_on_borders=False,
num_spikes_per_border=3,
border_size_samples=20,
Expand All @@ -143,7 +142,7 @@ def generate_sorting(
The firing rate of each unit (in Hz).
empty_units : list, default: None
List of units that will have no spikes. (used for testing mainly).
refractory_period_ms : float, default: 3.0
refractory_period_ms : float, default: 4.0
The refractory period in ms
add_spikes_on_borders : bool, default: False
If True, spikes will be added close to the borders of the segments.
Expand Down Expand Up @@ -881,9 +880,10 @@ def generate_single_fake_waveform(
depolarization_ms=(0.09, 0.14),
repolarization_ms=(0.5, 0.8),
recovery_ms=(1.0, 1.5),
positive_amplitude=(0.05, 0.15),
positive_amplitude=(0.1, 0.25),
smooth_ms=(0.03, 0.07),
decay_power=(1.4, 1.8),
propagation_speed=(250.0, 350.0), # ms / um
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down Expand Up @@ -931,13 +931,14 @@ def generate_templates(
An optional dict containing parameters per units.
Keys are parameter names:

* 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000))
* 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000))
* 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14))
* 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8))
* 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5))
* 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1)
* 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07))
* 'decay_power': the decay power (default range: (1.2-1.8))
* 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)).
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
Values contains vector with same size of num_units.
If the key is not in dict then it is generated using unit_params_range
unit_params_range: dict of tuple
Expand Down Expand Up @@ -985,12 +986,16 @@ def generate_templates(
assert unit_params[k].size == num_units
params[k] = unit_params[k]
else:
v = rng.random(num_units)
if k in unit_params_range:
lim0, lim1 = unit_params_range[k]
lims = unit_params_range[k]
else:
lims = default_unit_params_range[k]
if lims is not None:
lim0, lim1 = lims
v = rng.random(num_units)
params[k] = v * (lim1 - lim0) + lim0
else:
lim0, lim1 = default_unit_params_range[k]
params[k] = v * (lim1 - lim0) + lim0
params[k] = [None] * num_units

for u in range(num_units):
wf = generate_single_fake_waveform(
Expand All @@ -1006,17 +1011,42 @@ def generate_templates(
dtype=dtype,
)

## Add a spatial decay depend on distance from unit to each channel
alpha = params["alpha"][u]
# the espilon avoid enormous factors
eps = 1.0
# naive formula for spatial decay
pow = params["decay_power"][u]
channel_factors = alpha / (distances[u, :] + eps) ** pow
wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :]

# This mimic a propagation delay for distant channel
propagation_speed = params["propagation_speed"][u]
if propagation_speed is not None:
# the speed is um/ms
dist = distances[u, :].copy()
dist -= np.min(dist)
delay_s = dist / propagation_speed / 1000.0
sample_shifts = delay_s * fs

# apply the delay with fft transform to get sub sample shift
n = wfs.shape[0]
wfs_f = np.fft.rfft(wfs, axis=0)
if n % 2 == 0:
# n is even sig_f[-1] is nyquist and so pi
omega = np.linspace(0, np.pi, wfs_f.shape[0])
else:
# n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!!
omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0])
# broadcast omega and sample_shifts depend the axis
shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :]
wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0)
Comment on lines +1027 to +1043
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!


if upsample_factor is not None:
for f in range(upsample_factor):
templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :]
templates[u, :, :, f] = wfs[f::upsample_factor]
else:
templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :]
templates[u, :, :] = wfs

return templates

Expand Down Expand Up @@ -1322,12 +1352,19 @@ def generate_ground_truth_recording(
num_units=10,
sorting=None,
probe=None,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
templates=None,
ms_before=1.0,
ms_after=3.0,
upsample_factor=None,
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5),
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
generate_templates_kwargs=dict(),
Expand All @@ -1350,7 +1387,9 @@ def generate_ground_truth_recording(
sorting: Sorting or None
An external sorting object. If not provide, one is genrated.
probe: Probe or None
An external Probe object. If not provided of linear probe is generated.
An external Probe object. If not provided a probe is generated using generate_probe_kwargs.
generate_probe_kwargs: dict
A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`.
templates: np.array or None
The templates of units.
If None they are generated.
Expand Down Expand Up @@ -1407,8 +1446,28 @@ def generate_ground_truth_recording(
num_spikes = sorting.to_spike_vector().size

if probe is None:
probe = generate_linear_probe(num_elec=num_channels)
# probe = generate_linear_probe(num_elec=num_channels)
# probe.set_device_channel_indices(np.arange(num_channels))

prb_kwargs = generate_probe_kwargs.copy()
if "num_contact_per_column" in prb_kwargs:
assert (
prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"]
) == num_channels, (
"generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns"
)
elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs:
n = num_channels // prb_kwargs["num_columns"]
num_contact_per_column = [n] * prb_kwargs["num_columns"]
mid = prb_kwargs["num_columns"] // 2
num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"]
prb_kwargs["num_contact_per_column"] = num_contact_per_column
else:
raise ValueError("num_columns should be provided in dict generate_probe_kwargs")

probe = generate_multi_columns_probe(**prb_kwargs)
probe.set_device_channel_indices(np.arange(num_channels))

else:
num_channels = probe.get_contact_count()

Expand Down