diff --git a/doc/api.rst b/doc/api.rst index c73cd812da..3e825084e7 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -338,14 +338,60 @@ spikeinterface.curation spikeinterface.generation ------------------------- +Core +~~~~ .. automodule:: spikeinterface.generation + .. autofunction:: generate_recording + .. autofunction:: generate_sorting + .. autofunction:: generate_snippets + .. autofunction:: generate_templates + .. autofunction:: generate_recording_by_size + .. autofunction:: generate_ground_truth_recording + .. autofunction:: add_synchrony_to_sorting + .. autofunction:: synthesize_random_firings + .. autofunction:: inject_some_duplicate_units + .. autofunction:: inject_some_split_units + .. autofunction:: synthetize_spike_train_bad_isi + .. autofunction:: inject_templates + .. autofunction:: noise_generator_recording + .. autoclass:: InjectTemplatesRecording + .. autoclass:: NoiseGeneratorRecording + +Drift +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_drifting_recording + .. autofunction:: generate_displacement_vector + .. autofunction:: make_one_displacement_vector .. autofunction:: make_linear_displacement .. autofunction:: move_dense_templates .. autofunction:: interpolate_templates .. autoclass:: DriftingTemplates .. autoclass:: InjectDriftingTemplatesRecording +Hybrid +~~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_hybrid_recording + .. autofunction:: estimate_templates_from_recording + .. autofunction:: select_templates + .. autofunction:: scale_template_to_range + .. autofunction:: relocate_templates + .. autofunction:: fetch_template_object_from_database + .. autofunction:: fetch_templates_database_info + .. autofunction:: list_available_datasets_in_template_database + .. autofunction:: query_templates_from_database + + +Noise +~~~~~ +.. automodule:: spikeinterface.generation + + .. 
autofunction:: generate_noise + spikeinterface.sortingcomponents -------------------------------- diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 9e8c6c7d65..5870d87955 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -9,7 +9,7 @@ with known spiking activity. The template (aka average waveforms) of the injected units can be from previous spike sorted data. In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on -`DANDI `__). +`DANDI `_). Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in diff --git a/doc/modules/generation.rst b/doc/modules/generation.rst index a647919489..191cb57f30 100644 --- a/doc/modules/generation.rst +++ b/doc/modules/generation.rst @@ -1,9 +1,28 @@ Generation module ================= -The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes. -This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets). +The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes, +which can be used as "ground-truth" for benchmarking spike sorting algorithms. +There are several approaches to generating such recordings. +One possibility is to generate purely synthetic recordings. Another approach is to use real +recordings and add synthetic spikes to them, to make "hybrid" recordings. +The advantage of the former is that the ground-truth is known exactly, which is useful for benchmarking. +The advantage of the latter is that the spikes are added to real noise, which can be more realistic. 
-The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex -machinery, for instance generating recordings that possess various types of drift. +For hybrid recordings, the main challenge is to generate realistic spike templates. +We therefore built an open database of templates that we have constructed from the International +Brain Laboratory - Brain Wide Map (available on +`DANDI `_). +You can check out this collection of over 600 templates from this `web app `_. + +The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates, +manipulate them (e.g. rescaling and relocating), and construct hybrid recordings with them. +Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. +Such drifts can be taken into account in order to smoothly inject spikes into the recording. + +The :py:mod:`spikeinterface.generation` also includes functions to generate different kinds of drift signals and drifting +recordings, as well as generating synthetic noise profiles of various types. + +Some of the generation functions are defined in the :py:mod:`spikeinterface.core.generate` module, but also exposed at the +:py:mod:`spikeinterface.generation` level for convenience. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4f3977d7bb..f5312f9c46 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -103,11 +103,11 @@ def generate_sorting( Parameters ---------- num_units : int, default: 5 - Number of units + Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency + The sampling frequency. durations : list, default: [10.325, 3.5] - Duration of each segment in s + Duration of each segment in s. firing_rates : float, default: 3.0 The firing rate of each unit (in Hz).
empty_units : list, default: None @@ -123,12 +123,12 @@ border_size_samples : int, default: 20 The size of the border in samples to add border spikes. seed : int, default: None - The random seed + The random seed. Returns ------- sorting : NumpySorting - The sorting object + The sorting object. """ seed = _ensure_seed(seed) rng = np.random.default_rng(seed) @@ -187,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. sync_event_ratio : float The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). seed : int, default: None - The random seed + The random seed. Returns ------- sorting : TransformSorting - The sorting object, keeping track of added spikes + The sorting object, keeping track of added spikes. """ rng = np.random.default_rng(seed) @@ -249,18 +249,18 @@ def generate_sorting_to_inject( Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. num_samples: list of size num_segments. The number of samples in all the segments of the sorting, to generate spike times - covering entire the entire duration of the segments + covering the entire duration of the segments. max_injected_per_unit: int, default 1000 - The maximal number of spikes injected per units + The maximal number of spikes injected per unit. injected_rate: float, default 0.05 - The rate at which spikes are injected + The rate at which spikes are injected. refractory_period_ms: float, default 1.5 - The refractory period that should not be violated while injecting new spikes + The refractory period that should not be violated while injecting new spikes. seed: int, default None - The random seed + The random seed.
Returns ------- @@ -312,22 +312,22 @@ class TransformSorting(BaseSorting): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. added_spikes_existing_units : np.array (spike_vector) - The spikes that should be added to the sorting object, for existing units + The spikes that should be added to the sorting object, for existing units. added_spikes_new_units: np.array (spike_vector) - The spikes that should be added to the sorting object, for new units + The spikes that should be added to the sorting object, for new units. new_units_ids: list - The unit_ids that should be added if spikes for new units are added + The unit_ids that should be added if spikes for new units are added. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. Returns ------- sorting : TransformSorting - The sorting object with the added spikes and/or units + The sorting object with the added spikes and/or units. """ def __init__( @@ -428,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: the first sorting - sorting2: the second sorting + sorting1: BaseSorting + The first sorting. + sorting2: BaseSorting + The second sorting. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ assert ( sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency() @@ -492,12 +494,14 @@ def add_from_unit_dict( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting dict_list: list of dict + A list of dict with unit_ids as keys and spike times as values. 
refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency()) sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) @@ -515,18 +519,19 @@ def from_times_labels( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting times_list: list of array (or array) - An array of spike times (in frames) + An array of spike times (in frames). labels_list: list of array (or array) - An array of spike labels corresponding to the given times + An array of spike labels corresponding to the given times. unit_ids: list or None, default: None The explicit list of unit_ids that should be extracted from labels_list - If None, then it will be np.unique(labels_list) + If None, then it will be np.unique(labels_list). refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids) @@ -556,6 +561,16 @@ def clean_refractory_period(self): def create_sorting_npz(num_seg, file_path): + """ + Create a NPZ sorting file. + + Parameters + ---------- + num_seg : int + The number of segments. + file_path : str | Path + The file path to save the NPZ file. + """ # create a NPZ sorting file d = {} d["unit_ids"] = np.array([0, 1, 2], dtype="int64") @@ -585,6 +600,35 @@ def generate_snippets( empty_units=None, **job_kwargs, ): + """ + Generates a synthetic Snippets object. + + Parameters + ---------- + nbefore : int, default: 20 + Number of samples before the peak. 
+ nafter : int, default: 44 + Number of samples after the peak. + num_channels : int, default: 2 + Number of channels. + wf_folder : str | Path | None, default: None + Optional folder to save the waveform snippets. If None, snippets are in memory. + sampling_frequency : float, default: 30000.0 + The sampling frequency of the snippets. + ndim : int, default: 2 + The number of dimensions of the probe. + num_units : int, default: 5 + The number of units. + empty_units : list | None, default: None + A list of units that will have no spikes. + + Returns + ------- + snippets : NumpySnippets + The snippets object. + sorting : NumpySorting + The associated sorting object. + """ recording = generate_recording( durations=durations, num_channels=num_channels, @@ -645,18 +689,18 @@ def synthesize_poisson_spike_vector( Parameters ---------- num_units : int, default: 20 - Number of neuronal units to simulate + Number of neuronal units to simulate. sampling_frequency : float, default: 30000.0 - Sampling frequency in Hz + Sampling frequency in Hz. duration : float, default: 60.0 - Duration of the simulation in seconds + Duration of the simulation in seconds. refractory_period_ms : float, default: 4.0 - Refractory period between spikes in milliseconds + Refractory period between spikes in milliseconds. firing_rates : float or array_like or tuple, default: 3.0 Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with - each element being the firing rate for one unit + each element being the firing rate for one unit. seed : int, default: 0 - Seed for random number generator + Seed for random number generator. Returns ------- @@ -750,27 +794,27 @@ def synthesize_random_firings( Parameters ---------- num_units : int - number of units + Number of units. sampling_frequency : float - sampling rate + Sampling rate. duration : float - duration of the segment in seconds + Duration of the segment in seconds. 
refractory_period_ms: float - refractory_period in ms + Refractory period in ms. firing_rates: float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. add_shift_shuffle: bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. seed: int, default: None - seed for the generator + Seed for the generator. Returns ------- - times: - Concatenated and sorted times vector - labels: - Concatenated and sorted label vector + times: np.array + Concatenated and sorted times vector. + labels: np.array + Concatenated and sorted label vector. """ @@ -854,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No Parameters ---------- sorting : - Original sorting + Original sorting. num : int - Number of injected units + Number of injected units. max_shift : int - range of the shift in sample + range of the shift in sample. ratio: float Proportion of original spike in the injected units. @@ -909,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): - """ """ + """ + Inject some split units in a sorting. + + Parameters + ---------- + sorting : BaseSorting + Original sorting. + split_ids : list + List of unit_ids to split. + num_split : int, default: 2 + Number of split units. + output_ids : bool, default: False + If True, return the new unit_ids. + seed : int, default: None + Random seed. + Returns + ------- + sorting_with_split : NumpySorting + A sorting with split units. + """ unit_ids = sorting.unit_ids assert unit_ids.dtype.kind == "i" @@ -960,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol num_violations : int Number of contaminating spikes. 
violation_delta : float, default: 1e-5 - Temporal offset of contaminating spikes (in seconds) + Temporal offset of contaminating spikes (in seconds). Returns ------- @@ -1217,7 +1280,7 @@ def generate_recording_by_size( num_channels: int Number of channels. seed : int, default: None - The seed for np.random.default_rng + The seed for np.random.default_rng. Returns ------- @@ -1617,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording): * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. nbefore: list[int] | int | None, default: None - Where is the center of the template for each unit? + The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. amplitude_factor: list[float] | float | None, default: None The amplitude of each spike for each unit. @@ -1632,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording): You can use int for mono-segment objects. upsample_vector: np.array or None, default: None. When templates is 4d we can simulate a jitter. - Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3] + Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. 
Returns ------- diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 7a2291d932..5bf42ecf0f 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -14,6 +14,7 @@ relocate_templates, ) from .noise_tools import generate_noise + from .drifting_generator import ( make_one_displacement_vector, generate_displacement_vector, @@ -26,3 +27,22 @@ list_available_datasets_in_template_database, query_templates_from_database, ) + +# expose the core generate functions +from ..core.generate import ( + generate_recording, + generate_sorting, + generate_snippets, + generate_templates, + generate_recording_by_size, + generate_ground_truth_recording, + add_synchrony_to_sorting, + synthesize_random_firings, + inject_some_duplicate_units, + inject_some_split_units, + synthetize_spike_train_bad_isi, + NoiseGeneratorRecording, + noise_generator_recording, + InjectTemplatesRecording, + inject_templates, +) diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index 11f30e352f..685f0113b4 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -7,22 +7,25 @@ def generate_noise( probe, sampling_frequency, durations, dtype="float32", noise_levels=15.0, spatial_decay=None, seed=None ): """ + Generate a noise recording. Parameters ---------- probe : Probe A probe object. sampling_frequency : float - Sampling frequency + The sampling frequency of the recording. durations : list of float - Durations + The duration(s) of the recording segment(s) in seconds. dtype : np.dtype - Dtype - noise_levels : float | np.array | tuple + The dtype of the recording. + noise_levels : float | np.array | tuple, default: 15.0 If scalar same noises on all channels. If array then per channels noise level. If tuple, then this represent the range. 
- seed : None | int + spatial_decay : float | None, default: None + If not None, the spatial decay of the noise used to generate the noise covariance matrix. + seed : int | None, default: None The seed for random generator. Returns diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index e1cba07c8e..17d2bdf521 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -20,7 +20,7 @@ def fetch_template_object_from_database(dataset="test_templates.zarr") -> Templa Returns ------- Templates - _description_ + The templates object. """ s3_path = f"s3://spikeinterface-template-database/{dataset}/" zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True})