… into curation_metrics
jakeswann1 committed Jul 10, 2024
2 parents 18c0f2e + e4fa25a commit d2d1386
Showing 19 changed files with 606 additions and 151 deletions.
46 changes: 46 additions & 0 deletions doc/api.rst
@@ -338,14 +338,60 @@ spikeinterface.curation
spikeinterface.generation
-------------------------

Core
~~~~
.. automodule:: spikeinterface.generation

.. autofunction:: generate_recording
.. autofunction:: generate_sorting
.. autofunction:: generate_snippets
.. autofunction:: generate_templates
.. autofunction:: generate_recording_by_size
.. autofunction:: generate_ground_truth_recording
.. autofunction:: add_synchrony_to_sorting
.. autofunction:: synthesize_random_firings
.. autofunction:: inject_some_duplicate_units
.. autofunction:: inject_some_split_units
.. autofunction:: synthetize_spike_train_bad_isi
.. autofunction:: inject_templates
.. autofunction:: noise_generator_recording
.. autoclass:: InjectTemplatesRecording
.. autoclass:: NoiseGeneratorRecording

Drift
~~~~~
.. automodule:: spikeinterface.generation

.. autofunction:: generate_drifting_recording
.. autofunction:: generate_displacement_vector
.. autofunction:: make_one_displacement_vector
.. autofunction:: make_linear_displacement
.. autofunction:: move_dense_templates
.. autofunction:: interpolate_templates
.. autoclass:: DriftingTemplates
.. autoclass:: InjectDriftingTemplatesRecording

Hybrid
~~~~~~
.. automodule:: spikeinterface.generation

.. autofunction:: generate_hybrid_recording
.. autofunction:: estimate_templates_from_recording
.. autofunction:: select_templates
.. autofunction:: scale_template_to_range
.. autofunction:: relocate_templates
.. autofunction:: fetch_template_object_from_database
.. autofunction:: fetch_templates_database_info
.. autofunction:: list_available_datasets_in_template_database
.. autofunction:: query_templates_from_database


Noise
~~~~~
.. automodule:: spikeinterface.generation

.. autofunction:: generate_noise


spikeinterface.sortingcomponents
--------------------------------
2 changes: 1 addition & 1 deletion doc/how_to/benchmark_with_hybrid_recordings.rst
@@ -9,7 +9,7 @@ with known spiking activity. The template (aka average waveforms) of the
injected units can be from previous spike sorted data. In this example,
we will be using an open database of templates that we have constructed
from the International Brain Laboratory - Brain Wide Map (available on
`DANDI <https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9>`__).
`DANDI <https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9>`_).

Importantly, recordings from long-shank probes, such as Neuropixels,
usually experience drifts. Such drifts have to be taken into account in
27 changes: 23 additions & 4 deletions doc/modules/generation.rst
@@ -1,9 +1,28 @@
Generation module
=================

The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes.
This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets).
The :py:mod:`spikeinterface.generation` module provides functions to generate recordings containing spikes,
which can be used as "ground-truth" for benchmarking spike sorting algorithms.

There are several approaches to generating such recordings.
One possibility is to generate purely synthetic recordings. Another approach is to use real
recordings and add synthetic spikes to them, to make "hybrid" recordings.
The advantage of the former is that the ground-truth is known exactly, which is useful for benchmarking.
The advantage of the latter is that the spikes are added to real noise, which can be more realistic.

The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex
machinery, for instance generating recordings that possess various types of drift.
For hybrid recordings, the main challenge is to generate realistic spike templates.
We therefore built an open database of templates from the International
Brain Laboratory - Brain Wide Map (available on
`DANDI <https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9>`_).
You can browse this collection of over 600 templates in this `web app <https://spikeinterface.github.io/hybrid_template_library/>`_.

The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates,
manipulate them (e.g. by rescaling and relocating them), and construct hybrid recordings with them.
Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts.
Such drifts can be taken into account in order to smoothly inject spikes into the recording.

The :py:mod:`spikeinterface.generation` module also includes functions to generate different kinds of drift signals and drifting
recordings, as well as synthetic noise profiles of various types.

Some of the generation functions are defined in the :py:mod:`spikeinterface.core.generate` module, but also exposed at the
:py:mod:`spikeinterface.generation` level for convenience.
26 changes: 19 additions & 7 deletions src/spikeinterface/core/base.py
@@ -7,7 +7,6 @@
import weakref
import json
import pickle
import os
import random
import string
from packaging.version import parse
@@ -41,7 +40,7 @@ class BaseExtractor:
# This replaces the old key_properties
# These are annotations/properties that always need to be
# dumped (for instance locations, groups, is_filtered, etc.)
_main_annotations = []
_main_annotations = ["name"]
_main_properties = []

# these properties are skipped by default in copy_metadata
@@ -79,6 +78,19 @@ def __init__(self, main_ids: Sequence) -> None:
# preferred context for multiprocessing
self._preferred_mp_context = None

@property
def name(self):
name = self._annotations.get("name", None)
return name if name is not None else self.__class__.__name__

@name.setter
def name(self, value):
if value is not None:
self.annotate(name=value)
else:
# we remove the annotation if it exists
_ = self._annotations.pop("name", None)

def get_num_segments(self) -> int:
# This is implemented in BaseRecording or BaseSorting
raise NotImplementedError
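The behaviour of the new `name` property can be checked in isolation. The following is a minimal sketch, assuming a hypothetical stand-in class `NamedExtractor` that models only the `_annotations` dict and `annotate()`; it is not the real `BaseExtractor`.

```python
class NamedExtractor:
    """Hypothetical stand-in for BaseExtractor; only the name logic is modelled."""

    def __init__(self):
        self._annotations = {}

    def annotate(self, **new_annotations):
        # Store annotations in a plain dict, as the hunk above assumes.
        self._annotations.update(new_annotations)

    @property
    def name(self):
        # Fall back to the class name when no "name" annotation is set.
        name = self._annotations.get("name", None)
        return name if name is not None else self.__class__.__name__

    @name.setter
    def name(self, value):
        if value is not None:
            self.annotate(name=value)
        else:
            # Setting None removes the annotation, restoring the default.
            self._annotations.pop("name", None)


extractor = NamedExtractor()
default_name = extractor.name   # class name, since nothing is annotated yet
extractor.name = "probe_A"
custom_name = extractor.name    # the annotated name
extractor.name = None
reset_name = extractor.name     # back to the class name
```

Because the value lives in the annotations and `"name"` is listed in `_main_annotations`, a custom name would presumably be kept when the object is dumped, while an unset name costs nothing.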
@@ -938,13 +950,14 @@ def save_to_folder(
folder.mkdir(parents=True, exist_ok=False)

# dump provenance
provenance_file = folder / f"provenance.json"
if self.check_serializability("json"):
provenance_file = folder / f"provenance.json"
self.dump(provenance_file)
elif self.check_serializability("pickle"):
provenance_file = folder / f"provenance.pkl"
self.dump(provenance_file)
else:
provenance_file.write_text(
json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8"
)
warnings.warn("The extractor is not serializable to file. The provenance will not be saved.")

self.save_metadata_to_folder(folder)
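The provenance fallback in this hunk (JSON first, then pickle, then a warning) can be sketched outside the class. This is an illustration under assumptions: `dump_provenance` is a hypothetical helper, and it probes serializability with `try`/`except` rather than the `check_serializability` method used in the real code.

```python
import json
import pickle
import tempfile
import warnings
from pathlib import Path


def dump_provenance(obj, folder):
    """Hypothetical helper mirroring the json -> pickle -> warning fallback."""
    folder = Path(folder)
    try:
        # First choice: human-readable JSON.
        text = json.dumps(obj)
    except (TypeError, ValueError):
        text = None
    if text is not None:
        path = folder / "provenance.json"
        path.write_text(text, encoding="utf8")
        return path
    try:
        # Second choice: pickle, for objects JSON cannot encode.
        payload = pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        payload = None
    if payload is not None:
        path = folder / "provenance.pkl"
        path.write_bytes(payload)
        return path
    # Last resort: write nothing and tell the user.
    warnings.warn("The object is not serializable to file. The provenance will not be saved.")
    return None


with tempfile.TemporaryDirectory() as tmp:
    json_name = dump_provenance({"a": 1}, tmp).name        # JSON-friendly dict
    pkl_name = dump_provenance({"a": {1, 2}}, tmp).name    # a set is not JSON-encodable
```

Compared with the previous behaviour, no placeholder `provenance.json` containing a warning message is written when serialization is impossible.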

@@ -1011,7 +1024,6 @@ def save_to_zarr(
cached: ZarrExtractor
Saved copy of the extractor.
"""
import zarr
from .zarrextractors import read_zarr

save_kwargs.pop("format", None)
44 changes: 26 additions & 18 deletions src/spikeinterface/core/baserecording.py
@@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets):
Internally handle list of RecordingSegment
"""

_main_annotations = ["is_filtered"]
_main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"]
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = [] # recording do not handle features
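Extending the parent's `_main_annotations` by concatenation, instead of redefining the list, keeps `"name"` in every subclass. A minimal sketch of the pattern, with hypothetical class names standing in for the real hierarchy:

```python
class Base:
    # Annotations that always need to be dumped.
    _main_annotations = ["name"]


class Recording(Base):
    # `+` builds a new list, so the parent's list is not mutated.
    _main_annotations = Base._main_annotations + ["is_filtered"]


class Snippets(Base):
    # No override: the attribute is inherited from Base unchanged.
    pass


recording_annotations = Recording._main_annotations
snippets_annotations = Snippets._main_annotations
```

This would also explain why the `_main_annotations = []` override is dropped from `BaseSnippets` in the same commit: an empty override would hide the inherited `"name"` entry.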

@@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
self.annotate(is_filtered=False)

def __repr__(self):

class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()

txt = self._repr_header()
@@ -57,7 +54,7 @@ def __repr__(self):
split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100
if split_index != -1:
first_line = txt[:split_index]
recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": "
recording_string_space = len(self.name) + 2 # Length of self.name plus ": "
white_space_to_align_with_first_line = " " * recording_string_space
second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip()
txt = first_line + "\n" + second_line
@@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6):
return txt

def _repr_header(self):
class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()
num_channels = self.get_num_channels()
sf_khz = self.get_sampling_frequency() / 1000.0
sf_hz = self.get_sampling_frequency()
sf_khz = sf_hz / 1000
dtype = self.get_dtype()

total_samples = self.get_total_samples()
total_duration = self.get_total_duration()
total_memory_size = self.get_total_memory_size()
sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

txt = (
f"{name_to_display}: "
f"{self.name}: "
f"{num_channels} channels - "
f"{sf_khz:0.1f}kHz - "
f"{sampling_frequency_repr} - "
f"{num_segments} segments - "
f"{total_samples:,} samples - "
f"{convert_seconds_to_str(total_duration)} - "
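The new unit switch in `_repr_header` can be isolated as a small function. The expression below is copied from the hunk; the wrapper name `format_sampling_frequency` is hypothetical.

```python
def format_sampling_frequency(sf_hz: float) -> str:
    """Show kHz only above 10 kHz; otherwise keep the value in Hz."""
    sf_khz = sf_hz / 1000
    return f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"


high_rate = format_sampling_frequency(30_000.0)  # typical extracellular rate
low_rate = format_sampling_frequency(1_000.0)    # e.g. a downsampled signal
boundary = format_sampling_frequency(10_000.0)   # comparison is strict, so still Hz
```

The comparison is strict, so a recording sampled at exactly 10 kHz is still reported in Hz.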
@@ -501,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None):
rs = self._recording_segments[segment_index]
return rs.time_to_sample_index(time_s)

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
def _get_t_starts(self):
# handle t_starts
t_starts = []
has_time_vectors = []
for segment_index, rs in enumerate(self._recording_segments):
for rs in self._recording_segments:
d = rs.get_times_kwargs()
t_starts.append(d["t_start"])
has_time_vectors.append(d["time_vector"] is not None)

if all(t_start is None for t_start in t_starts):
t_starts = None
return t_starts

def _get_time_vectors(self):
time_vectors = []
for rs in self._recording_segments:
d = rs.get_times_kwargs()
time_vectors.append(d["time_vector"])
if all(time_vector is None for time_vector in time_vectors):
time_vectors = None
return time_vectors

def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()
t_starts = self._get_t_starts()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

@@ -575,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
if time_vector is not None:
cached._recording_segments[segment_index].time_vector = time_vector
time_vectors = self._get_time_vectors()
if time_vectors is not None:
for segment_index, time_vector in enumerate(time_vectors):
if time_vector is not None:
cached.set_times(time_vector, segment_index=segment_index)

return cached
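Both new helpers, `_get_t_starts` and `_get_time_vectors`, follow the same shape: collect one value per segment, then collapse the list to `None` when every entry is `None`. A minimal sketch of that shape, assuming plain dicts in place of the per-segment `get_times_kwargs()` results and a hypothetical helper name `collect_or_none`:

```python
def collect_or_none(values):
    """Return the per-segment list, or None when every entry is None."""
    values = list(values)
    return None if all(value is None for value in values) else values


# Hypothetical per-segment time kwargs: only the second segment has a t_start.
segments = [
    {"t_start": None, "time_vector": None},
    {"t_start": 5.0, "time_vector": None},
]

t_starts = collect_or_none(seg["t_start"] for seg in segments)
time_vectors = collect_or_none(seg["time_vector"] for seg in segments)
```

Callers then need a single `is not None` check before looping over segments, which is how the rewritten `_save` guards the `set_times` calls.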

1 change: 0 additions & 1 deletion src/spikeinterface/core/basesnippets.py
@@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets):
Abstract class representing several multichannel snippets.
"""

_main_annotations = []
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = []

3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
@@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List):
self._cached_spike_trains = {}

def __repr__(self):
clsname = self.__class__.__name__
nseg = self.get_num_segments()
nunits = self.get_num_units()
sf_khz = self.get_sampling_frequency() / 1000.0
txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
if "file_path" in self._kwargs:
txt += "\n file_path: {}".format(self._kwargs["file_path"])
return txt