diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index dc84d31987..44ea02d32c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -8,7 +8,7 @@ from .numpyextractors import NumpyRecording, NumpySorting from .basesorting import minimum_spike_dtype -from probeinterface import Probe, generate_linear_probe +from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting @@ -93,7 +93,6 @@ def generate_recording( probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) - probe = generate_linear_probe(num_elec=num_channels) return recording @@ -122,7 +121,7 @@ def generate_sorting( durations=[10.325, 3.5], #  in s for 2 segments firing_rates=3.0, empty_units=None, - refractory_period_ms=3.0, # in ms + refractory_period_ms=4.0, # in ms add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, @@ -143,7 +142,7 @@ def generate_sorting( The firing rate of each unit (in Hz). empty_units : list, default: None List of units that will have no spikes. (used for testing mainly). - refractory_period_ms : float, default: 3.0 + refractory_period_ms : float, default: 4.0 The refractory period in ms add_spikes_on_borders : bool, default: False If True, spikes will be added close to the borders of the segments. @@ -881,9 +880,10 @@ def generate_single_fake_waveform( depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), - positive_amplitude=(0.05, 0.15), + positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), + propagation_speed=(250.0, 350.0), # um / ms ) @@ -931,13 +931,14 @@ def generate_templates( An optional dict containing parameters per units. Keys are parameter names: - * 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000)) + * 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000)) * 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14)) * 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8)) * 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5)) * 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) * 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07)) * 'decay_power': the decay power (default range: (1.2-1.8)) + * 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values contains vector with same size of num_units. If the key is not in dict then it is generated using unit_params_range unit_params_range: dict of tuple @@ -985,12 +986,16 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - v = rng.random(num_units) if k in unit_params_range: - lim0, lim1 = unit_params_range[k] + lims = unit_params_range[k] + else: + lims = default_unit_params_range[k] + if lims is not None: + lim0, lim1 = lims + v = rng.random(num_units) + params[k] = v * (lim1 - lim0) + lim0 else: - lim0, lim1 = default_unit_params_range[k] - params[k] = v * (lim1 - lim0) + lim0 + params[k] = [None] * num_units for u in range(num_units): wf = generate_single_fake_waveform( @@ -1006,17 +1011,42 @@ def generate_templates( dtype=dtype, ) + ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors eps = 1.0 # naive formula for spatial decay pow = params["decay_power"][u] channel_factors = alpha / (distances[u, :] + eps) ** pow + wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + + # This mimic a propagation delay for distant channel + propagation_speed = params["propagation_speed"][u] + if propagation_speed is not None: + # the speed is um/ms + dist = distances[u, :].copy() + dist -= np.min(dist) + delay_s = dist / propagation_speed / 1000.0 + sample_shifts = delay_s * fs + + # apply the delay with fft transform to get sub sample shift + n = wfs.shape[0] + wfs_f = np.fft.rfft(wfs, axis=0) + if n % 2 == 0: + # n is even sig_f[-1] is nyquist and so pi + omega = np.linspace(0, np.pi, wfs_f.shape[0]) + else: + # n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!! + omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0]) + # broadcast omega and sample_shifts depend the axis + shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :] + wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0) + if upsample_factor is not None: for f in range(upsample_factor): - templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :, f] = wfs[f::upsample_factor] else: - templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :] = wfs return templates @@ -1322,12 +1352,19 @@ def generate_ground_truth_recording( num_units=10, sorting=None, probe=None, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), templates=None, ms_before=1.0, ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), @@ -1350,7 +1387,9 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of linear probe is generated. + An external Probe object. If not provided a probe is generated using generate_probe_kwargs. + generate_probe_kwargs: dict + A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. @@ -1407,8 +1446,28 @@ def generate_ground_truth_recording( num_spikes = sorting.to_spike_vector().size if probe is None: - probe = generate_linear_probe(num_elec=num_channels) + # probe = generate_linear_probe(num_elec=num_channels) + # probe.set_device_channel_indices(np.arange(num_channels)) + + prb_kwargs = generate_probe_kwargs.copy() + if "num_contact_per_column" in prb_kwargs: + assert ( + prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] + ) == num_channels, ( + "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" + ) + elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: + n = num_channels // prb_kwargs["num_columns"] + num_contact_per_column = [n] * prb_kwargs["num_columns"] + mid = prb_kwargs["num_columns"] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] + prb_kwargs["num_contact_per_column"] = num_contact_per_column + else: + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + + probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) + else: num_channels = probe.get_contact_count()