diff --git a/doc/api.rst b/doc/api.rst index c73cd812da..3e825084e7 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -338,14 +338,60 @@ spikeinterface.curation spikeinterface.generation ------------------------- +Core +~~~~ .. automodule:: spikeinterface.generation + .. autofunction:: generate_recording + .. autofunction:: generate_sorting + .. autofunction:: generate_snippets + .. autofunction:: generate_templates + .. autofunction:: generate_recording_by_size + .. autofunction:: generate_ground_truth_recording + .. autofunction:: add_synchrony_to_sorting + .. autofunction:: synthesize_random_firings + .. autofunction:: inject_some_duplicate_units + .. autofunction:: inject_some_split_units + .. autofunction:: synthetize_spike_train_bad_isi + .. autofunction:: inject_templates + .. autofunction:: noise_generator_recording + .. autoclass:: InjectTemplatesRecording + .. autoclass:: NoiseGeneratorRecording + +Drift +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_drifting_recording + .. autofunction:: generate_displacement_vector + .. autofunction:: make_one_displacement_vector .. autofunction:: make_linear_displacement .. autofunction:: move_dense_templates .. autofunction:: interpolate_templates .. autoclass:: DriftingTemplates .. autoclass:: InjectDriftingTemplatesRecording +Hybrid +~~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_hybrid_recording + .. autofunction:: estimate_templates_from_recording + .. autofunction:: select_templates + .. autofunction:: scale_template_to_range + .. autofunction:: relocate_templates + .. autofunction:: fetch_template_object_from_database + .. autofunction:: fetch_templates_database_info + .. autofunction:: list_available_datasets_in_template_database + .. autofunction:: query_templates_from_database + + +Noise +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_noise + spikeinterface.sortingcomponents -------------------------------- diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 9e8c6c7d65..5870d87955 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -9,7 +9,7 @@ with known spiking activity. The template (aka average waveforms) of the injected units can be from previous spike sorted data. In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on -`DANDI `__). +`DANDI `_). Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in diff --git a/doc/modules/generation.rst b/doc/modules/generation.rst index a647919489..191cb57f30 100644 --- a/doc/modules/generation.rst +++ b/doc/modules/generation.rst @@ -1,9 +1,28 @@ Generation module ================= -The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes. -This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets). +The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes, +which can be used as "ground-truth" for benchmarking spike sorting algorithms. +There are several approaches to generating such recordings. +One possibility is to generate purely synthetic recordings. 
Another approach is to use real
+recordings and add synthetic spikes to them, to make "hybrid" recordings.
+The advantage of the former is that the ground-truth is known exactly, which is useful for benchmarking.
+The advantage of the latter is that the spikes are added to real noise, which can be more realistic.

-The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex
-machinery, for instance generating recordings that possess various types of drift.
+For hybrid recordings, the main challenge is to generate realistic spike templates.
+We therefore built an open database of templates that we have constructed from the International
+Brain Laboratory - Brain Wide Map (available on
+`DANDI `_).
+You can check out this collection of over 600 templates from this `web app `_.
+
+The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates,
+manipulate them (e.g. rescaling and relocating), and construct hybrid recordings with them.
+Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts.
+Such drifts can be taken into account in order to smoothly inject spikes into the recording.
+
+The :py:mod:`spikeinterface.generation` module also includes functions to generate different kinds of drift signals and drifting
+recordings, as well as synthetic noise profiles of various types.
+
+Some of the generation functions are defined in the :py:mod:`spikeinterface.core.generate` module, but are also exposed at the
+:py:mod:`spikeinterface.generation` level for convenience.
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 37e7b83e62..b2d147ce4c 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -7,7 +7,6 @@
 import weakref
 import json
 import pickle
-import os
 import random
 import string
 from packaging.version import parse
@@ -41,7 +40,7 @@ class BaseExtractor:
     # This replaces the old key_properties
     # These are annotations/properties that always need to be
    # dumped (for instance locations, groups, is_fileterd, etc.)
-    _main_annotations = []
+    _main_annotations = ["name"]
     _main_properties = []

     # these properties are skipped by default in copy_metadata
@@ -79,6 +78,19 @@ def __init__(self, main_ids: Sequence) -> None:
         # preferred context for multiprocessing
         self._preferred_mp_context = None

+    @property
+    def name(self):
+        name = self._annotations.get("name", None)
+        return name if name is not None else self.__class__.__name__
+
+    @name.setter
+    def name(self, value):
+        if value is not None:
+            self.annotate(name=value)
+        else:
+            # we remove the annotation if it exists
+            _ = self._annotations.pop("name", None)
+
     def get_num_segments(self) -> int:
         # This is implemented in BaseRecording or BaseSorting
         raise NotImplementedError
@@ -941,10 +953,11 @@ def save_to_folder(

         provenance_file_path = folder / f"provenance.json"
         if self.check_serializability("json"):
             self.dump_to_json(file_path=provenance_file_path, relative_to=folder)
+        elif self.check_serializability("pickle"):
+            provenance_file = folder / f"provenance.pkl"
+            self.dump_to_pickle(provenance_file, relative_to=folder)
         else:
-            provenance_file_path.write_text(
-                json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8"
-            )
+            warnings.warn("The extractor is not serializable to file. 
The provenance will not be saved.") self.save_metadata_to_folder(folder) @@ -1012,7 +1025,6 @@ def save_to_zarr( cached: ZarrExtractor Saved copy of the extractor. """ - import zarr from .zarrextractors import read_zarr save_kwargs.pop("format", None) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e70c95bb65..0ea9426674 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets): Internally handle list of RecordingSegment """ - _main_annotations = ["is_filtered"] + _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] # recording do not handle features @@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): self.annotate(is_filtered=False) def __repr__(self): - - class_name = self.__class__.__name__ - name_to_display = class_name num_segments = self.get_num_segments() txt = self._repr_header() @@ -57,7 +54,7 @@ def __repr__(self): split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100 if split_index != -1: first_line = txt[:split_index] - recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": " + recording_string_space = len(self.name) + 2 # Length of self.name plus ": " white_space_to_align_with_first_line = " " * recording_string_space second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip() txt = first_line + "\n" + second_line @@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6): return txt def _repr_header(self): - class_name = self.__class__.__name__ - name_to_display = class_name num_segments = self.get_num_segments() num_channels = self.get_num_channels() - sf_khz = self.get_sampling_frequency() / 1000.0 + sf_hz = self.get_sampling_frequency() + sf_khz = sf_hz / 1000 dtype = self.get_dtype() total_samples = self.get_total_samples() total_duration = self.get_total_duration() total_memory_size = self.get_total_memory_size() + sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( - f"{name_to_display}: " + f"{self.name}: " f"{num_channels} channels - " - f"{sf_khz:0.1f}kHz - " + f"{sampling_frequency_repr} - " f"{num_segments} segments - " f"{total_samples:,} samples - " f"{convert_seconds_to_str(total_duration)} - " @@ -501,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None): rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) - def _save(self, format="binary", verbose: bool = False, **save_kwargs): + def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for segment_index, rs in enumerate(self._recording_segments): + for rs in self._recording_segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) - has_time_vectors.append(d["time_vector"] is not None) if all(t_start is None for t_start in t_starts): t_starts = None + return t_starts + + def _get_time_vectors(self): + time_vectors = [] + for rs in self._recording_segments: + d = rs.get_times_kwargs() + time_vectors.append(d["time_vector"]) + if all(time_vector is None for time_vector in time_vectors): + time_vectors = None + return time_vectors + def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": folder = kwargs["folder"] 
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + t_starts = self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -575,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) - for segment_index, rs in enumerate(self._recording_segments): - d = rs.get_times_kwargs() - time_vector = d["time_vector"] - if time_vector is not None: - cached._recording_segments[segment_index].time_vector = time_vector + time_vectors = self._get_time_vectors() + if time_vectors is not None: + for segment_index, time_vector in enumerate(time_vectors): + if time_vector is not None: + cached.set_times(time_vector, segment_index=segment_index) return cached diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 1f3fee74a8..869842779d 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_annotations = [] _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index d9a567dedf..2af48407a3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List): self._cached_spike_trains = {} def __repr__(self): - clsname = self.__class__.__name__ nseg = self.get_num_segments() nunits = self.get_num_units() sf_khz = self.get_sampling_frequency() / 1000.0 - txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz" + txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz" if "file_path" in self._kwargs: txt += "\n file_path: {}".format(self._kwargs["file_path"]) return txt diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 11909bce0e..6ce94114c4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -80,6 +80,8 @@ def generate_recording( probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) + recording.name = "SyntheticRecording" + return recording @@ -101,11 +103,11 @@ def generate_sorting( Parameters ---------- num_units : int, default: 5 - Number of units + Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency + The sampling frequency. durations : list, default: [10.325, 3.5] - Duration of each segment in s + Duration of each segment in s. firing_rates : float, default: 3.0 The firing rate of each unit (in Hz). empty_units : list, default: None @@ -121,12 +123,12 @@ def generate_sorting( border_size_samples : int, default: 20 The size of the border in samples to add border spikes. seed : int, default: None - The random seed + The random seed. Returns ------- sorting : NumpySorting - The sorting object + The sorting object. """ seed = _ensure_seed(seed) rng = np.random.default_rng(seed) @@ -185,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. 
sync_event_ratio : float
         The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means
         that the final sorting will have 1.5 times number of spikes, and all the extra spikes are
         synchronous (same sample_index), but on different units (not duplicates).
     seed : int, default: None
-        The random seed
+        The random seed.

     Returns
     -------
     sorting : TransformSorting
-        The sorting object, keeping track of added spikes
+        The sorting object, keeping track of added spikes.
     """
     rng = np.random.default_rng(seed)
@@ -247,18 +249,18 @@ def generate_sorting_to_inject(
     Parameters
     ----------
     sorting : BaseSorting
-        The sorting object
+        The sorting object.
     num_samples: list of size num_segments.
         The number of samples in all the segments of the sorting, to generate spike times
-        covering entire the entire duration of the segments
+        covering the entire duration of the segments.
     max_injected_per_unit: int, default 1000
-        The maximal number of spikes injected per units
+        The maximal number of spikes injected per unit.
     injected_rate: float, default 0.05
-        The rate at which spikes are injected
+        The rate at which spikes are injected.
     refractory_period_ms: float, default 1.5
-        The refractory period that should not be violated while injecting new spikes
+        The refractory period that should not be violated while injecting new spikes.
     seed: int, default None
-        The random seed
+        The random seed.

     Returns
     -------
@@ -310,22 +312,22 @@ class TransformSorting(BaseSorting):
     Parameters
     ----------
     sorting : BaseSorting
-        The sorting object
+        The sorting object.
     added_spikes_existing_units : np.array (spike_vector)
-        The spikes that should be added to the sorting object, for existing units
+        The spikes that should be added to the sorting object, for existing units.
     added_spikes_new_units: np.array (spike_vector)
-        The spikes that should be added to the sorting object, for new units
+        The spikes that should be added to the sorting object, for new units.
     new_units_ids: list
-        The unit_ids that should be added if spikes for new units are added
+        The unit_ids that should be added if spikes for new units are added.
     refractory_period_ms : float, default None
         The refractory period violation to prevent duplicates and/or unphysiological addition
         of spikes. Any spike times in added_spikes violating the refractory period will be
-        discarded
+        discarded.

     Returns
     -------
     sorting : TransformSorting
-        The sorting object with the added spikes and/or units
+        The sorting object with the added spikes and/or units.
     """

     def __init__(
@@ -426,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe

         Parameters
         ----------
-        sorting1: the first sorting
-        sorting2: the second sorting
+        sorting1: BaseSorting
+            The first sorting.
+        sorting2: BaseSorting
+            The second sorting.
         refractory_period_ms : float, default None
             The refractory period violation to prevent duplicates and/or unphysiological addition
             of spikes. Any spike times in added_spikes violating the refractory period will be
-            discarded
+            discarded.
         """
         assert (
             sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency()
@@ -490,12 +494,14 @@ def add_from_unit_dict(

         Parameters
         ----------
-        sorting1: the first sorting
+        sorting1: BaseSorting
+            The first sorting.
         dict_list: list of dict
+            A list of dict with unit_ids as keys and spike times as values.
         refractory_period_ms : float, default None
             The refractory period violation to prevent duplicates and/or unphysiological addition
             of spikes. 
Any spike times in added_spikes violating the refractory period will be
-            discarded
+            discarded.
         """
         sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency())
         sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms)
@@ -513,18 +519,19 @@ def from_times_labels(

         Parameters
         ----------
-        sorting1: the first sorting
+        sorting1: BaseSorting
+            The first sorting.
         times_list: list of array (or array)
-            An array of spike times (in frames)
+            An array of spike times (in frames).
         labels_list: list of array (or array)
-            An array of spike labels corresponding to the given times
+            An array of spike labels corresponding to the given times.
         unit_ids: list or None, default: None
             The explicit list of unit_ids that should be extracted from labels_list
-            If None, then it will be np.unique(labels_list)
+            If None, then it will be np.unique(labels_list).
         refractory_period_ms : float, default None
             The refractory period violation to prevent duplicates and/or unphysiological addition
             of spikes. Any spike times in added_spikes violating the refractory period will be
-            discarded
+            discarded.
         """

         sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids)
@@ -554,6 +561,16 @@ def clean_refractory_period(self):


 def create_sorting_npz(num_seg, file_path):
+    """
+    Create a NPZ sorting file.
+
+    Parameters
+    ----------
+    num_seg : int
+        The number of segments.
+    file_path : str | Path
+        The file path to save the NPZ file.
+    """
     # create a NPZ sorting file
     d = {}
     d["unit_ids"] = np.array([0, 1, 2], dtype="int64")
@@ -583,6 +600,37 @@ def generate_snippets(
     empty_units=None,
     **job_kwargs,
 ):
+    """
+    Generates a synthetic Snippets object.
+
+    Parameters
+    ----------
+    nbefore : int, default: 20
+        Number of samples before the peak.
+    nafter : int, default: 44
+        Number of samples after the peak.
+    num_channels : int, default: 2
+        Number of channels.
+    wf_folder : str | Path | None, default: None
+        Optional folder to save the waveform snippets. If None, snippets are in memory.
+    sampling_frequency : float, default: 30000.0
+        The sampling frequency of the snippets.
+    durations : list of float, default: [10.325, 3.5]
+        The duration(s) of the recording segment(s) in seconds.
+    ndim : int, default: 2
+        The number of dimensions of the probe.
+    num_units : int, default: 5
+        The number of units.
+    empty_units : list | None, default: None
+        A list of units that will have no spikes.
+
+    Returns
+    -------
+    snippets : NumpySnippets
+        The snippets object.
+    sorting : NumpySorting
+        The associated sorting object.
+    """
     recording = generate_recording(
         durations=durations,
         num_channels=num_channels,
@@ -643,18 +689,18 @@ def synthesize_poisson_spike_vector(

     Parameters
     ----------
     num_units : int, default: 20
-        Number of neuronal units to simulate
+        Number of neuronal units to simulate.
     sampling_frequency : float, default: 30000.0
-        Sampling frequency in Hz
+        Sampling frequency in Hz.
     duration : float, default: 60.0
-        Duration of the simulation in seconds
+        Duration of the simulation in seconds.
     refractory_period_ms : float, default: 4.0
-        Refractory period between spikes in milliseconds
+        Refractory period between spikes in milliseconds.
     firing_rates : float or array_like or tuple, default: 3.0
         Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with
-        each element being the firing rate for one unit
+        each element being the firing rate for one unit.
     seed : int, default: 0
-        Seed for random number generator
+        Seed for random number generator. 
Returns
     -------
@@ -748,27 +794,27 @@ def synthesize_random_firings(
     Parameters
     ----------
     num_units : int
-        number of units
+        Number of units.
     sampling_frequency : float
-        sampling rate
+        Sampling rate.
     duration : float
-        duration of the segment in seconds
+        Duration of the segment in seconds.
     refractory_period_ms: float
-        refractory_period in ms
+        Refractory period in ms.
     firing_rates: float or list[float]
         The firing rate of each unit (in Hz).
         If float, all units will have the same firing rate.
     add_shift_shuffle: bool, default: False
         Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat.
     seed: int, default: None
-        seed for the generator
+        Seed for the generator.

     Returns
     -------
-    times:
-        Concatenated and sorted times vector
-    labels:
-        Concatenated and sorted label vector
+    times: np.array
+        Concatenated and sorted times vector.
+    labels: np.array
+        Concatenated and sorted label vector.
     """

@@ -852,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
     Parameters
     ----------
     sorting :
-        Original sorting
+        Original sorting.
     num : int
-        Number of injected units
+        Number of injected units.
     max_shift : int
-        range of the shift in sample
+        Range of the shift in samples.
     ratio: float
         Proportion of original spike in the injected units.

@@ -907,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No


 def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None):
-    """ """
+    """
+    Inject some split units in a sorting.

+    Parameters
+    ----------
+    sorting : BaseSorting
+        Original sorting.
+    split_ids : list
+        List of unit_ids to split.
+    num_split : int, default: 2
+        Number of units to split each unit in `split_ids` into.
+    output_ids : bool, default: False
+        If True, return the new unit_ids.
+    seed : int, default: None
+        Random seed.
+
+    Returns
+    -------
+    sorting_with_split : NumpySorting
+        A sorting with split units.
+    """
     unit_ids = sorting.unit_ids
     assert unit_ids.dtype.kind == "i"
@@ -958,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
     num_violations : int
         Number of contaminating spikes.
     violation_delta : float, default: 1e-5
-        Temporal offset of contaminating spikes (in seconds)
+        Temporal offset of contaminating spikes (in seconds).

     Returns
     -------
@@ -1215,7 +1280,7 @@ def generate_recording_by_size(
     num_channels: int
         Number of channels.
     seed : int, default: None
-        The seed for np.random.default_rng
+        The seed for np.random.default_rng.

     Returns
     -------
@@ -1615,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording):
         * (num_units, num_samples, num_channels): standard case
         * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter.
     nbefore: list[int] | int | None, default: None
-        Where is the center of the template for each unit?
+        The number of samples before the peak of the template, used to align the injected spikes.
         If None, will default to the highest peak.
     amplitude_factor: list[float] | float | None, default: None
         The amplitude of each spike for each unit.
@@ -1630,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording):
         You can use int for mono-segment objects.
     upsample_vector: np.array or None, default: None.
         When templates is 4d we can simulate a jitter.
-        Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3]
+        Optionally, the upsample_vector is the jitter index, with one value per spike in the range 0-templates.shape[3]. 
Returns ------- @@ -1738,6 +1803,8 @@ def __init__( ) self.add_recording_segment(recording_segment) + # to discuss: maybe we could set json serializability to False always + # because templates could be large! if not sorting.check_serializability("json"): self._serializability["json"] = False if parent_recording is not None: @@ -2122,4 +2189,7 @@ def generate_ground_truth_recording( recording.set_channel_gains(1.0) recording.set_channel_offsets(0.0) + recording.name = "GroundTruthRecording" + sorting.name = "GroundTruthSorting" + return recording, sorting diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 09ba743a8c..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N @staticmethod def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + + t_starts = source_recording._get_t_starts() + if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array # this can lead to problem @@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs): for shm in shms: shm.close() shm.unlink() - # TODO later : propagte t_starts ? + recording = NumpyRecording( traces_list, source_recording.get_sampling_frequency(), - t_starts=None, + t_starts=t_starts, channel_ids=source_recording.channel_ids, ) + return recording class NumpyRecordingSegment(BaseRecordingSegment): @@ -206,7 +210,7 @@ def __del__(self): def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) - # TODO later : propagte t_starts ? 
+ t_starts = source_recording._get_t_starts() recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], @@ -214,7 +218,7 @@ def from_recording(source_recording, **job_kwargs): dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, - t_starts=None, + t_starts=t_starts, main_shm_owner=True, ) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index b23b7202c6..039fa8fd60 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -191,16 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices): seg_start = self.cumsum_length[i] if i == i0: # first - traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices) + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame_, channel_indices) all_traces.append(traces_chunk) elif i == i1: # last if (end_frame - seg_start) > 0: - traces_chunk = rec_seg.get_traces(None, end_frame - seg_start, channel_indices) + start_frame_ = 0 + traces_chunk = rec_seg.get_traces(start_frame_, end_frame - seg_start, channel_indices) all_traces.append(traces_chunk) else: # in between - traces_chunk = rec_seg.get_traces(None, None, channel_indices) + start_frame_ = 0 + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame_, end_frame_, channel_indices) all_traces.append(traces_chunk) traces = np.concatenate(all_traces, axis=0) diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 947c5686d8..7f55646b63 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -4,8 +4,9 @@ """ from typing import Sequence +import numpy as np from spikeinterface.core.base import BaseExtractor -from spikeinterface.core import generate_recording, concatenate_recordings +from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings class DummyDictExtractor(BaseExtractor): @@ -65,6 +66,34 @@ def test_check_if_serializable(): assert not extractor.check_serializability("json") +def test_name_and_repr(): + test_recording, test_sorting = generate_ground_truth_recording(seed=0, durations=[2]) + assert test_recording.name == "GroundTruthRecording" + assert test_sorting.name == "GroundTruthSorting" + + # set a different name + test_recording.name = "MyRecording" + assert test_recording.name == "MyRecording" + + # to/from dict + test_recording_dict = test_recording.to_dict() + test_recording2 = BaseExtractor.from_dict(test_recording_dict) + assert test_recording2.name == "MyRecording" + + # repr + rec_str = str(test_recording2) + assert "MyRecording" in rec_str + test_recording2.name = None + assert "MyRecording" not in str(test_recording2) + assert test_recording2.__class__.__name__ in str(test_recording2) + # above 10khz, sampling frequency is printed in kHz + assert f"kHz" in rec_str + # below 10khz sampling frequency is printed in Hz + test_rec_low_fs = generate_recording(seed=0, durations=[2], sampling_frequency=5000) + rec_str = str(test_rec_low_fs) + assert "Hz" in rec_str + + if __name__ == "__main__": test_check_if_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_segmentutils.py b/src/spikeinterface/core/tests/test_segmentutils.py index d3c73805f0..166ecafd09 100644 --- 
a/src/spikeinterface/core/tests/test_segmentutils.py
+++ b/src/spikeinterface/core/tests/test_segmentutils.py
@@ -5,10 +5,6 @@
 from numpy.testing import assert_raises

 from spikeinterface.core import (
-    AppendSegmentRecording,
-    AppendSegmentSorting,
-    ConcatenateSegmentRecording,
-    ConcatenateSegmentSorting,
     NumpyRecording,
     NumpySorting,
     append_recordings,
diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py
index 487a893096..049d5ab6e5 100644
--- a/src/spikeinterface/core/tests/test_time_handling.py
+++ b/src/spikeinterface/core/tests/test_time_handling.py
@@ -1,69 +1,289 @@
+import copy
+
 import pytest
 import numpy as np

 from spikeinterface.core import generate_recording, generate_sorting
+import spikeinterface.full as si
+

+class TestTimeHandling:
+    """
+    This class tests how time is handled in SpikeInterface. Under the hood,
+    time can be represented as a full `time_vector` or only as a
+    `t_start` attribute on segments from which a vector of times
+    is generated on the fly. Both time representations are tested here.
+    """

-def test_time_handling(create_cache_folder):
-    cache_folder = create_cache_folder
-    durations = [[10], [10, 5]]
+    # Fixtures #####
+    @pytest.fixture(scope="session")
+    def time_vector_recording(self):
+        """
+        Add time vectors to the recording, returning the
+        raw recording, the recording with time vectors added to
+        segments, and a list of the time vectors added to the recording.
+        """
+        durations = [10, 15, 20]
+        raw_recording = generate_recording(num_channels=4, durations=durations)

-    # test multi-segment
-    for i, dur in enumerate(durations):
-        rec = generate_recording(num_channels=4, durations=dur)
-        sort = generate_sorting(num_units=10, durations=dur)
+        return self._get_time_vector_recording(raw_recording)

-        for segment_index in range(rec.get_num_segments()):
-            original_times = rec.get_times(segment_index=segment_index)
-            new_times = original_times + 5
-            rec.set_times(new_times, segment_index=segment_index)
+    @pytest.fixture(scope="session")
+    def t_start_recording(self):
+        """
+        Add t_starts to the recording, returning the
+        raw recording, the recording with t_starts added to segments,
+        and a list of the time vectors generated from adding the
+        t_start to the recording times.
+        """
+        durations = [10, 15, 20]
+        raw_recording = generate_recording(num_channels=4, durations=durations)

-        sort.register_recording(rec)
-        assert sort.has_recording()
+        return self._get_t_start_recording(raw_recording)

-        rec_cache = rec.save(folder=cache_folder / f"rec{i}")
+    def _get_time_vector_recording(self, raw_recording):
+        """
+        Loop through all recording segments, adding a different time
+        vector to each segment. The time vector is the original times with
+        a t_start and irregularly spaced offsets to mimic irregularly
+        spaced timeseries data. Return the original recording, the
+        recording with time vectors added, and a list of the added time vectors. 
+        """
+        times_recording = copy.deepcopy(raw_recording)
+        all_time_vectors = []
+        for segment_index in range(raw_recording.get_num_segments()):

-        for segment_index in range(sort.get_num_segments()):
-            assert rec.has_time_vector(segment_index=segment_index)
-            assert sort.has_time_vector(segment_index=segment_index)
+            t_start = (segment_index + 1) * 100

-            # times are correctly saved by the recording
-            assert np.allclose(
-                rec.get_times(segment_index=segment_index), rec_cache.get_times(segment_index=segment_index)
+            some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * (
+                1 / times_recording.get_sampling_frequency()
             )
-            # spike times are correctly adjusted
-            for u in sort.get_unit_ids():
-                spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True)
-                rec_times = rec.get_times(segment_index=segment_index)
-                assert np.all(spike_times >= rec_times[0])
-                assert np.all(spike_times <= rec_times[-1])
+            offsets = np.cumsum(some_small_increasing_numbers)
+            time_vector = t_start + times_recording.get_times(segment_index) + offsets
+
+            all_time_vectors.append(time_vector)
+            times_recording.set_times(times=time_vector, segment_index=segment_index)
+
+            assert np.array_equal(
+                times_recording._recording_segments[segment_index].time_vector,
+                time_vector,
+            ), "time_vector was not properly set during test setup"
+
+        return (raw_recording, times_recording, all_time_vectors)
+
+    def _get_t_start_recording(self, raw_recording):
+        """
+        For each segment in the recording, add a different `t_start`.
+        Return a list of time vectors generated from the recording times
+        + the t_starts.
+        """
+        t_start_recording = copy.deepcopy(raw_recording)
+
+        all_t_starts = []
+        for segment_index in range(raw_recording.get_num_segments()):
+
+            t_start = (segment_index + 1) * 100
+
+            all_t_starts.append(t_start + t_start_recording.get_times(segment_index))
+            t_start_recording._recording_segments[segment_index].t_start = t_start
+
+        return (raw_recording, t_start_recording, all_t_starts)
+
+    def _get_fixture_data(self, request, fixture_name):
+        """
+        A convenience function to get the data from a fixture
+        based on its name. This is used to allow parameterising
+        tests across fixtures.
+        """
+        time_recording_fixture = request.getfixturevalue(fixture_name)
+        raw_recording, times_recording, all_times = time_recording_fixture
+        return (raw_recording, times_recording, all_times)
+
+    # Tests #####
+    def test_has_time_vector(self, time_vector_recording):
+        """
+        Test the `has_time_vector` function returns `False` before
+        a time vector is added and `True` afterwards.
+        """
+        raw_recording, times_recording, _ = time_vector_recording
+
+        for segment_idx in range(raw_recording.get_num_segments()):
+
+            assert raw_recording.has_time_vector(segment_idx) is False
+            assert times_recording.has_time_vector(segment_idx) is True
+
+    @pytest.mark.parametrize("mode", ["binary", "zarr"])
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path):
+        """
+        Test `t_start` or `time_vector` is propagated to a saved recording,
+        by saving, reloading, and checking times are correct. 
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+
+        folder_name = "recording"
+        recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)
+
+        if mode == "zarr":
+            folder_name += ".zarr"
+        recording_load = si.load_extractor(tmp_path / folder_name)
+
+        self._check_times_match(recording_cache, all_times)
+        self._check_times_match(recording_load, all_times)
+
+    @pytest.mark.parametrize("sharedmem", [True, False])
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
+        """
+        Test t_start and time_vector are propagated to a recording saved into memory.
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+
+        recording_load = times_recording.save(format="memory", sharedmem=sharedmem)

+        self._check_times_match(recording_load, all_times)

-def test_frame_slicing():
-    duration = [10]
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_time_propagated_to_select_segments(self, request, fixture_name):
+        """
+        Test that when `recording.select_segments()` is used, the times
+        are propagated to the new recording object.
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)

-    rec = generate_recording(num_channels=4, durations=duration)
-    sort = generate_sorting(num_units=10, durations=duration)
+        for segment_index in range(times_recording.get_num_segments()):
+            segment = times_recording.select_segments(segment_index)
+            assert np.array_equal(segment.get_times(), all_times[segment_index])

-    original_times = rec.get_times()
-    new_times = original_times + 5
-    rec.set_times(new_times)
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_times_propagated_to_sorting(self, request, fixture_name):
+        """
+        Check that when attached to a sorting object, the times are propagated
+        to the object. This means that all spike times should respect the
+        `t_start` or `time_vector` added.
+        """
+        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+        sorting = self._get_sorting_with_recording_attached(
+            recording_for_durations=raw_recording, recording_to_attach=times_recording
+        )

+        for segment_index in range(raw_recording.get_num_segments()):

-    sort.register_recording(rec)
+            if fixture_name == "time_vector_recording":
+                assert sorting.has_time_vector(segment_index=segment_index)

-    start_frame = 3 * rec.get_sampling_frequency()
-    end_frame = 7 * rec.get_sampling_frequency()
+            self._check_spike_times_are_correct(sorting, times_recording, segment_index)
+
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_time_sample_converters(self, request, fixture_name):
+        """
+        Test the `recording.sample_index_to_time` and
+        `recording.time_to_sample_index` convenience functions. 
+        """
+        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)

+        with pytest.raises(ValueError) as e:
+            times_recording.sample_index_to_time(0)
+        assert "Provide 'segment_index'" in str(e)
+
+        for segment_index in range(times_recording.get_num_segments()):
+
+            sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
+            time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)
+
+            assert time_ == all_times[segment_index][sample_index]
+
+            new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)
+
+            assert new_sample_index == sample_index
+
+    @pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
+    @pytest.mark.parametrize("bounds", ["start", "middle", "end"])
+    def test_slice_recording(self, time_type, bounds):
+        """
+        Test times are correct after applying `frame_slice` or `time_slice`
+        to a recording or sorting (for `frame_slice`). The recording times
+        should be correct with respect to the set `t_start` or `time_vector`.
+        """
+        raw_recording = generate_recording(num_channels=4, durations=[10])
+
+        if time_type == "time_vector":
+            raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
+        else:
+            raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)
+
+        sorting = self._get_sorting_with_recording_attached(
+            recording_for_durations=raw_recording, recording_to_attach=times_recording
+        )
+
+        # Take some different times, including the min and max bounds of
+        # the recording, and some arbitrary times in the middle (20% and 80%).
+        if bounds == "start":
+            start_frame = 0
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+        elif bounds == "end":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = times_recording.get_num_samples(0) - 1
+        elif bounds == "middle":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+
+        # Slice the recording and check the new times are correct
+        rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
+        sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)
+
+        assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+        self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)
+
+        # Test `time_slice`
+        start_time = times_recording.sample_index_to_time(start_frame)
+        end_time = times_recording.sample_index_to_time(end_frame)
+
+        rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)
+
+        assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+    # Helpers ####
+    def _check_times_match(self, recording, all_times):
+        """
+        For every segment in a recording, check that `get_times()`
+        matches the expected times in the list of time vectors, `all_times`.
+        """
+        for segment_index in range(recording.get_num_segments()):
+            assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])
+
+    def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
+        """
+        For every unit in the `sorting`, for a particular segment, check that
+        the unit times match the times of the original recording as
+        retrieved with `get_times()`. 
+        """
+        for unit_id in sorting.get_unit_ids():
+            spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
+            spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
+            rec_times = times_recording.get_times(segment_index=segment_index)
+
+            assert np.array_equal(
+                spike_times,
+                rec_times[spike_indexes],
+            )

-    rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame)
-    sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame)
+    def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
+        """
+        Convenience function to create a sorting object with
+        a recording attached. Typically, the raw recording is
+        used for the durations when making the sorting, as
+        generate_sorting is not set up to handle the
+        (strange) edge case of the irregularly spaced
+        test time vectors.
+        """
+        durations = [
+            recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
+        ]

-    for u in sort_slice.get_unit_ids():
-        spike_times = sort_slice.get_unit_spike_train(u, return_times=True)
-        rec_times = rec_slice.get_times()
-        assert np.all(spike_times >= rec_times[0])
-        assert np.all(spike_times <= rec_times[-1])
+        sorting = generate_sorting(num_units=10, durations=durations)

+        sorting.register_recording(recording_to_attach)
+        assert sorting.has_recording()

-if __name__ == "__main__":
-    test_frame_slicing()
+        return sorting
diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py
index e56d4fff52..bd35180a7e 100644
--- a/src/spikeinterface/extractors/extractorlist.py
+++ b/src/spikeinterface/extractors/extractorlist.py
@@ -116,3 +116,21 @@
 event_extractor_full_list += neo_event_extractors_list

 snippets_extractor_full_list = [NpySnippetsExtractor, WaveClusSnippetsExtractor]
+
+recording_extractor_full_dict = {}
+for rec_class in recording_extractor_full_list:
+    # here we get the class name, remove "Recording" and "Extractor" and make it lower case
+    rec_class_name = rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower()
+    recording_extractor_full_dict[rec_class_name] = rec_class
+
+sorting_extractor_full_dict = {}
+for sort_class in sorting_extractor_full_list:
+    # here we get the class name, remove "Sorting" and "Extractor" and make it lower case
+    sort_class_name = sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower()
+    sorting_extractor_full_dict[sort_class_name] = sort_class
+
+event_extractor_full_dict = {}
+for event_class in event_extractor_full_list:
+    # here we get the class name, remove "Event" and "Extractor" and make it lower case
+    event_class_name = event_class.__name__.replace("Event", "").replace("Extractor", "").lower()
+    event_extractor_full_dict[event_class_name] = event_class
diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py
index 0b11b72b2a..bf52de7c1d 100644
--- a/src/spikeinterface/extractors/neoextractors/__init__.py
+++ b/src/spikeinterface/extractors/neoextractors/__init__.py
@@ -36,7 +36,7 @@
 )
 from .spike2 import Spike2RecordingExtractor, read_spike2
 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets
-from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx
+from .spikeglx import SpikeGLXRecordingExtractor, SpikeGLXEventExtractor, read_spikeglx, read_spikeglx_event
 from .tdt import TdtRecordingExtractor, read_tdt
 from .neo_utils import 
get_neo_streams, get_neo_num_blocks @@ -73,4 +73,9 @@ Plexon2SortingExtractor, ] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] +neo_event_extractors_list = [ + AlphaOmegaEventExtractor, + OpenEphysBinaryEventExtractor, + Plexon2EventExtractor, + SpikeGLXEventExtractor, +] diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 7a2291d932..5bf42ecf0f 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -14,6 +14,7 @@ relocate_templates, ) from .noise_tools import generate_noise + from .drifting_generator import ( make_one_displacement_vector, generate_displacement_vector, @@ -26,3 +27,22 @@ list_available_datasets_in_template_database, query_templates_from_database, ) + +# expose the core generate functions +from ..core.generate import ( + generate_recording, + generate_sorting, + generate_snippets, + generate_templates, + generate_recording_by_size, + generate_ground_truth_recording, + add_synchrony_to_sorting, + synthesize_random_firings, + inject_some_duplicate_units, + inject_some_split_units, + synthetize_spike_train_bad_isi, + NoiseGeneratorRecording, + noise_generator_recording, + InjectTemplatesRecording, + inject_templates, +) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index cce2e08b58..70e13160f4 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -458,6 +458,9 @@ def __init__( self.set_probe(drifting_templates.probe, in_place=True) + # templates are too large, we don't serialize them to JSON + self._serializability["json"] = False + self._kwargs = { "sorting": sorting, "drifting_templates": drifting_templates, diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index 11f30e352f..685f0113b4 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -7,22 +7,25 @@ def generate_noise( probe, sampling_frequency, durations, dtype="float32", noise_levels=15.0, spatial_decay=None, seed=None ): """ + Generate a noise recording. Parameters ---------- probe : Probe A probe object. sampling_frequency : float - Sampling frequency + The sampling frequency of the recording. durations : list of float - Durations + The duration(s) of the recording segment(s) in seconds. dtype : np.dtype - Dtype - noise_levels : float | np.array | tuple + The dtype of the recording. + noise_levels : float | np.array | tuple, default: 15.0 If scalar same noises on all channels. If array then per channels noise level. If tuple, then this represent the range. - seed : None | int + spatial_decay : float | None, default: None + If not None, the spatial decay of the noise used to generate the noise covariance matrix. + seed : int | None, default: None The seed for random generator. Returns diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index e1cba07c8e..17d2bdf521 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -20,7 +20,7 @@ def fetch_template_object_from_database(dataset="test_templates.zarr") -> Templa Returns ------- Templates - _description_ + The templates object. 
""" s3_path = f"s3://spikeinterface-template-database/{dataset}/" zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True})