From eb4e1021017da4066ceac3c39f91b231af6ef30d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 12 Oct 2023 16:58:11 +0200 Subject: [PATCH 1/9] Improve generate.py with spatial on generate_template() --- src/spikeinterface/core/generate.py | 50 +++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index dc84d31987..bbb953ccff 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -881,9 +881,10 @@ def generate_single_fake_waveform( depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), - positive_amplitude=(0.05, 0.15), + positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), + propagation_speed=(250., 350.), # ms / um ) @@ -938,6 +939,7 @@ def generate_templates( * 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) * 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07)) * 'decay_power': the decay power (default range: (1.2-1.8)) + * 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values contains vector with same size of num_units. If the key is not in dict then it is generated using unit_params_range unit_params_range: dict of tuple @@ -985,12 +987,17 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - v = rng.random(num_units) + if k in unit_params_range: - lim0, lim1 = unit_params_range[k] + lims = unit_params_range[k] else: - lim0, lim1 = default_unit_params_range[k] - params[k] = v * (lim1 - lim0) + lim0 + lims = default_unit_params_range[k] + if lims is not None: + lim0, lim1 = lims + v = rng.random(num_units) + params[k] = v * (lim1 - lim0) + lim0 + else: + params[k] = [None] * num_units for u in range(num_units): wf = generate_single_fake_waveform( @@ -1006,17 +1013,46 @@ def generate_templates( dtype=dtype, ) + + ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors eps = 1.0 # naive formula for spatial decay pow = params["decay_power"][u] channel_factors = alpha / (distances[u, :] + eps) ** pow + wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + + # This mimic a propagation delay for distant channel + propagation_speed = params["propagation_speed"][u] + if propagation_speed is not None: + # the speed is um/ms + dist = distances[u, :].copy() + dist -= np.min(dist) + delay_s = dist / propagation_speed / 1000. + sample_shifts = delay_s * fs + + # apply the delay with fft transform to get sub sample shift + n = wfs.shape[0] + wfs_f = np.fft.rfft(wfs, axis=0) + if n % 2 == 0: + # n is even sig_f[-1] is nyquist and so pi + omega = np.linspace(0, np.pi, wfs_f.shape[0]) + else: + # n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!! + omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0]) + # broadcast omega and sample_shifts depend the axis + shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :] + wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0) + if upsample_factor is not None: for f in range(upsample_factor): - templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :, f] = wfs[f::upsample_factor] else: - templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :] = wfs + + + return templates From 0fd12553a70582c3a5cac5b007fe32ac439ddb48 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 16 Oct 2023 19:45:43 +0200 Subject: [PATCH 2/9] Use multi columns probe in generate_ground_truth_recording() --- src/spikeinterface/core/generate.py | 36 +++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbb953ccff..6bb5a384e6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -8,7 +8,7 @@ from .numpyextractors import NumpyRecording, NumpySorting from .basesorting import minimum_spike_dtype -from probeinterface import Probe, generate_linear_probe +from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting @@ -93,7 +93,6 @@ def generate_recording( probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) - probe = generate_linear_probe(num_elec=num_channels) return recording @@ -1358,6 +1357,13 @@ def generate_ground_truth_recording( num_units=10, sorting=None, probe=None, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), templates=None, ms_before=1.0, ms_after=3.0, @@ -1386,7 +1392,9 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of linear probe is generated. + An external Probe object. If not provided of probe is generated using generate_probe_kwargs. + generate_probe_kwargs: dict + A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. @@ -1443,8 +1451,28 @@ def generate_ground_truth_recording( num_spikes = sorting.to_spike_vector().size if probe is None: - probe = generate_linear_probe(num_elec=num_channels) + # probe = generate_linear_probe(num_elec=num_channels) + # probe.set_device_channel_indices(np.arange(num_channels)) + + prb_kwargs = generate_probe_kwargs.copy() + if 'num_contact_per_column' in prb_kwargs: + assert (prb_kwargs['num_contact_per_column'] * prb_kwargs['num_columns']) == num_channels, \ + "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" + elif 'num_contact_per_column' not in prb_kwargs and 'num_columns' in prb_kwargs: + n = num_channels // prb_kwargs['num_columns'] + num_contact_per_column = [n] * prb_kwargs['num_columns'] + mid = prb_kwargs['num_columns'] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs['num_columns'] + prb_kwargs['num_contact_per_column'] = num_contact_per_column + else: + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + + print(prb_kwargs) + probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) + print(probe) + + else: num_channels = probe.get_contact_count() From 0b57f4dccb6d2561737362b8351993eaac0939d3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 17 Oct 2023 10:29:27 +0200 Subject: [PATCH 3/9] harmonize refactory period in generate.py --- src/spikeinterface/core/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6bb5a384e6..d6924f6f4f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -121,7 +121,7 @@ def generate_sorting( durations=[10.325, 3.5], #  in s for 2 segments firing_rates=3.0, empty_units=None, - refractory_period_ms=3.0, # in ms + refractory_period_ms=4.0, # in ms add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, @@ -142,7 +142,7 @@ def generate_sorting( The firing rate of each unit (in Hz). empty_units : list, default: None List of units that will have no spikes. (used for testing mainly). - refractory_period_ms : float, default: 3.0 + refractory_period_ms : float, default: 4.0 The refractory period in ms add_spikes_on_borders : bool, default: False If True, spikes will be added close to the borders of the segments. @@ -1369,7 +1369,7 @@ def generate_ground_truth_recording( ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), From 737812f22560d2390ea5869f3579dfede3e7c28d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 17 Oct 2023 10:31:56 +0200 Subject: [PATCH 4/9] clean --- src/spikeinterface/core/generate.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d6924f6f4f..de69af85f3 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1467,11 +1467,8 @@ def generate_ground_truth_recording( else: raise ValueError("num_columns should be provided in dict generate_probe_kwargs") - print(prb_kwargs) probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) - print(probe) - else: num_channels = probe.get_contact_count() From 296751249af7e19ebdbec5586d3d08f31838ca74 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 23 Oct 2023 20:45:55 +0200 Subject: [PATCH 5/9] Update src/spikeinterface/core/generate.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index de69af85f3..7ac6f0cd36 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1392,7 +1392,7 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of probe is generated using generate_probe_kwargs. + An external Probe object. If not provided a probe is generated using generate_probe_kwargs. generate_probe_kwargs: dict A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None From 383ee2ac4f91f88df41c12ddf87f5a745dfcc87b Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 23 Oct 2023 20:46:07 +0200 Subject: [PATCH 6/9] Update src/spikeinterface/core/generate.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 7ac6f0cd36..e876888186 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1394,7 +1394,7 @@ def generate_ground_truth_recording( probe: Probe or None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. generate_probe_kwargs: dict - A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. + A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. From 439d2fed55b376ec0e011c6903d0211d9421d695 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 18:46:14 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 30 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e876888186..0285d60f4f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -883,7 +883,7 @@ def generate_single_fake_waveform( positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), - propagation_speed=(250., 350.), # ms / um + propagation_speed=(250.0, 350.0), # ms / um ) @@ -986,7 +986,6 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - if k in unit_params_range: lims = unit_params_range[k] else: @@ -1012,7 +1011,6 @@ def generate_templates( dtype=dtype, ) - ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors @@ -1028,7 +1026,7 @@ def generate_templates( # the speed is um/ms dist = distances[u, :].copy() dist -= np.min(dist) - delay_s = dist / propagation_speed / 1000. + delay_s = dist / propagation_speed / 1000.0 sample_shifts = delay_s * fs # apply the delay with fft transform to get sub sample shift @@ -1050,9 +1048,6 @@ def generate_templates( else: templates[u, :, :] = wfs - - - return templates @@ -1369,7 +1364,7 @@ def generate_ground_truth_recording( ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), @@ -1455,15 +1450,18 @@ def generate_ground_truth_recording( # probe.set_device_channel_indices(np.arange(num_channels)) prb_kwargs = generate_probe_kwargs.copy() - if 'num_contact_per_column' in prb_kwargs: - assert (prb_kwargs['num_contact_per_column'] * prb_kwargs['num_columns']) == num_channels, \ + if "num_contact_per_column" in prb_kwargs: + assert ( + prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] + ) == num_channels, ( "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" - elif 'num_contact_per_column' not in prb_kwargs and 'num_columns' in prb_kwargs: - n = num_channels // prb_kwargs['num_columns'] - num_contact_per_column = [n] * prb_kwargs['num_columns'] - mid = prb_kwargs['num_columns'] // 2 - num_contact_per_column[mid] += num_channels % prb_kwargs['num_columns'] - prb_kwargs['num_contact_per_column'] = num_contact_per_column + ) + elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: + n = num_channels // prb_kwargs["num_columns"] + num_contact_per_column = [n] * prb_kwargs["num_columns"] + mid = prb_kwargs["num_columns"] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] + prb_kwargs["num_contact_per_column"] = num_contact_per_column else: raise ValueError("num_columns should be provided in dict generate_probe_kwargs") From 2aedd47ec47edc048e312a38434f90200ef98d22 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 20:49:09 +0200 Subject: [PATCH 8/9] generate alpha doc --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0285d60f4f..e879651ae7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -931,7 +931,7 @@ def generate_templates( An optional dict containing parameters per units. Keys are parameter names: - * 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000)) + * 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000)) * 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14)) * 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8)) * 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5)) From 66280e4354fa3e8a8f1c1b923005bf882c78d5a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 07:14:21 +0200 Subject: [PATCH 9/9] Update src/spikeinterface/core/generate.py --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e879651ae7..44ea02d32c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -883,7 +883,7 @@ def generate_single_fake_waveform( positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), - propagation_speed=(250.0, 350.0), # ms / um + propagation_speed=(250.0, 350.0), # um / ms )