Commit

Merge branch 'main' into tdc_2
samuelgarcia authored Oct 6, 2023
2 parents f6e2f59 + a2d27ff commit f631484
Showing 37 changed files with 202 additions and 150 deletions.
10 changes: 7 additions & 3 deletions src/spikeinterface/core/base.py
@@ -45,8 +45,12 @@ def __init__(self, main_ids: Sequence) -> None:
         self._kwargs = {}

         # 'main_ids' will either be channel_ids or units_ids
-        # They is used for properties
+        # They are used for properties
         self._main_ids = np.array(main_ids)
+        if len(self._main_ids) > 0:
+            assert (
+                self._main_ids.dtype.kind in "uiSU"
+            ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}"

         # dict at object level
         self._annotations = {}
@@ -984,7 +988,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
     class_name = None

     if "kwargs" not in dic:
-        raise Exception(f"This dict cannot be load into extractor {dic}")
+        raise Exception(f"This dict cannot be loaded into extractor {dic}")

     # Create new kwargs to avoid modifying the original dict["kwargs"]
     new_kwargs = dict()
@@ -1005,7 +1009,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
     assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class"
     if not _check_same_version(class_name, dic["version"]):
         warnings.warn(
-            f"Versions are not the same. This might lead compatibility errors. "
+            f"Versions are not the same. This might lead to compatibility errors. "
            f"Using {class_name.split('.')[0]}=={dic['version']} is recommended"
        )

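For context, the new guard relies on NumPy's `dtype.kind` codes: `u` (unsigned int), `i` (signed int), `S` (bytes), `U` (unicode). A minimal standalone sketch of the same check (the function name is illustrative, not part of this commit):

```python
import numpy as np


def validate_main_ids(main_ids):
    """Mirror the new assertion: IDs must be integers or strings."""
    ids = np.array(main_ids)
    if len(ids) > 0:
        assert (
            ids.dtype.kind in "uiSU"
        ), f"Main IDs can only be integers (signed/unsigned) or strings, not {ids.dtype}"
    return ids


validate_main_ids([0, 1, 2])        # ok: dtype.kind == "i"
validate_main_ids(["ch0", "ch1"])   # ok: dtype.kind == "U"
# validate_main_ids([0.5, 1.5])     # fails: dtype.kind == "f"
```
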
7 changes: 4 additions & 3 deletions src/spikeinterface/core/baserecording.py
@@ -305,7 +305,8 @@ def get_traces(

         if not self.has_scaled():
             raise ValueError(
-                "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
+                "This recording does not support return_scaled=True (need gain_to_uV and offset_"
+                "to_uV properties)"
             )
         else:
             gains = self.get_property("gain_to_uV")
@@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True):
         if with_warning:
             warn(
                 "Setting times with Recording.set_times() is not recommended because "
-                "times are not always propagated to across preprocessing"
-                "Use use this carefully!"
+                "times are not always propagated across preprocessing"
+                "Use this carefully!"
             )

def sample_index_to_time(self, sample_ind, segment_index=None):
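
The `gain_to_uV` / `offset_to_uV` properties named in the error message drive the `return_scaled=True` path: raw integer traces are converted channel by channel to microvolts. A sketch of the arithmetic (toy values; the real ones come from `recording.get_property(...)`):

```python
import numpy as np

traces = np.array([[100, -200], [50, 25]], dtype="int16")  # (samples, channels)
gains = np.array([0.195, 0.195], dtype="float32")   # gain_to_uV, per channel
offsets = np.array([0.0, -10.0], dtype="float32")   # offset_to_uV, per channel

# Conceptually what return_scaled=True does before returning traces:
scaled = traces.astype("float32") * gains + offsets
print(scaled)  # traces in uV
```
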
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecordingsnippets.py
@@ -1,4 +1,4 @@
-from typing import List
+from __future__ import annotations
 from pathlib import Path

 import numpy as np
@@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor):

     has_default_locations = False

-    def __init__(self, sampling_frequency: float, channel_ids: List, dtype):
+    def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
         BaseExtractor.__init__(self, channel_ids)
         self._sampling_frequency = sampling_frequency
         self._dtype = np.dtype(dtype)
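
Swapping `typing.List` for `from __future__ import annotations` makes all annotations lazily evaluated (PEP 563), so builtin generics like `list[...]` can be written even on Python 3.8, where evaluating them eagerly raises a `TypeError`. A minimal illustration (not from the commit):

```python
from __future__ import annotations


def first_channel(channel_ids: list[str]) -> str:
    # With postponed evaluation the annotation above is stored as a
    # string and never executed, so this also works on Python 3.8.
    return channel_ids[0]


print(first_channel(["ch0", "ch1"]))  # -> ch0
```
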
2 changes: 0 additions & 2 deletions src/spikeinterface/core/basesnippets.py
@@ -1,10 +1,8 @@
-from typing import List, Union
 from pathlib import Path
 from .base import BaseSegment
 from .baserecordingsnippets import BaseRecordingSnippets
 import numpy as np
-from warnings import warn
 from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes

# snippets segments?

2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True):
         if check_spike_frames:
             if has_exceeding_spikes(recording, self):
                 warnings.warn(
-                    "Some spikes are exceeding the recording's duration! "
+                    "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
                     "Might be necessary for further postprocessing."
                 )
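
The warning points users to `spikeinterface.curation.remove_excess_spikes()`. A hedged usage sketch; only the function's dotted path comes from the warning text, while the toy generators and the argument order are assumptions:

```python
from spikeinterface.core import generate_recording, generate_sorting
from spikeinterface.curation import remove_excess_spikes

recording = generate_recording(durations=[5.0])
sorting = generate_sorting(durations=[10.0])  # spikes may fall past the 5 s recording

# Drop spike frames beyond the recording's duration.
clean_sorting = remove_excess_spikes(sorting, recording)

# Registering the cleaned sorting no longer triggers the warning:
clean_sorting.register_recording(recording)
```
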
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
@@ -91,7 +91,7 @@ def __init__(
             file_path_list = [Path(file_paths)]

         if t_starts is not None:
-            assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths"
+            assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths"
             t_starts = [float(t_start) for t_start in t_starts]

         dtype = np.dtype(dtype)
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelsaggregationrecording.py
@@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments):
         times_kargs0 = parent_segment0.get_times_kwargs()
         if times_kargs0["time_vector"] is None:
             for ps in parent_segments:
-                assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set"
+                assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set"
         else:
             for ps in parent_segments:
                 assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], (
-                    "All segment should have the same " "t_start"
+                    "All segments should have the same " "t_start"
                 )

         BaseRecordingSegment.__init__(self, **times_kargs0)
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelslice.py
@@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
), "ChannelSliceRecording: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceRecording : channel_ids not unique"
), "ChannelSliceRecording : channel_ids are not unique"

sampling_frequency = parent_recording.get_sampling_frequency()

@@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None):
), "ChannelSliceSnippets: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceSnippets : channel_ids not unique"
), "ChannelSliceSnippets : channel_ids are not unique"

sampling_frequency = parent_snippets.get_sampling_frequency()

2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicerecording.py
@@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording):
     def __init__(self, parent_recording, start_frame=None, end_frame=None):
         channel_ids = parent_recording.get_channel_ids()

-        assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment"
+        assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment"

         parent_size = parent_recording.get_num_samples(0)
         if start_frame is None:
8 changes: 4 additions & 4 deletions src/spikeinterface/core/frameslicesorting.py
@@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting):
     def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True):
         unit_ids = parent_sorting.get_unit_ids()

-        assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment"
+        assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment"

         if start_frame is None:
             start_frame = 0
@@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
             end_frame = parent_n_samples
         assert (
             end_frame <= parent_n_samples
-        ), "`end_frame` should be smaller than the sortings total number of samples."
+        ), "`end_frame` should be smaller than the sortings' total number of samples."
         assert (
             start_frame <= parent_n_samples
-        ), "`start_frame` should be smaller than the sortings total number of samples."
+        ), "`start_frame` should be smaller than the sortings' total number of samples."
         if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
             raise ValueError(
                 "The sorting object has spikes exceeding the recording duration. You have to remove those spikes "
@@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
             end_frame = max_spike_time + 1

         assert start_frame < end_frame, (
-            "`start_frame` should be greater than `end_frame`. "
+            "`start_frame` should be less than `end_frame`. "
             "This may be due to start_frame >= max_spike_time, if the end frame "
             "was not specified explicitly."
         )
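
These classes back the public `frame_slice()` helpers on recordings and sortings. A usage sketch with the toy generators from `spikeinterface.core` (a sketch, assuming the default 30 kHz generators; not part of this commit):

```python
from spikeinterface.core import generate_recording, generate_sorting

recording = generate_recording(durations=[10.0])
sorting = generate_sorting(durations=[10.0])
fs = recording.get_sampling_frequency()

# Keep only the first 5 seconds; start_frame must be < end_frame.
sub_rec = recording.frame_slice(start_frame=0, end_frame=int(5 * fs))
sub_sort = sorting.frame_slice(start_frame=0, end_frame=int(5 * fs))

print(sub_rec.get_num_samples(segment_index=0))  # ~5 * fs samples
```
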
4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
@@ -1101,11 +1101,11 @@ def __init__(
             # handle also upsampling and jitter
             upsample_factor = templates.shape[3]
         elif templates.ndim == 5:
-            # handle also dirft
+            # handle also drift
             raise NotImplementedError("Drift will be implented soon...")
             # upsample_factor = templates.shape[3]
         else:
-            raise ValueError("templates have wring dim should 3 or 4")
+            raise ValueError("templates have wrong dim should 3 or 4")

         if upsample_factor is not None:
             assert upsample_vector is not None
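
The branch dispatches on `templates.ndim`: 3 for plain templates, 4 when a trailing upsampling/jitter axis is present, 5 reserved for drift. A standalone sketch of that validation (the axis meanings are inferred from the surrounding code, so treat them as assumptions):

```python
import numpy as np


def check_templates(templates: np.ndarray):
    """Dispatch on template dimensionality, mirroring the diff."""
    upsample_factor = None
    if templates.ndim == 3:
        pass  # assumed (num_units, num_samples, num_channels)
    elif templates.ndim == 4:
        upsample_factor = templates.shape[3]  # extra upsampling/jitter axis
    elif templates.ndim == 5:
        raise NotImplementedError("Drift will be implemented soon...")
    else:
        raise ValueError("templates have wrong dim, should be 3 or 4")
    return upsample_factor


print(check_templates(np.zeros((5, 90, 4))))      # None
print(check_templates(np.zeros((5, 90, 4, 10))))  # 10
```
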
5 changes: 4 additions & 1 deletion src/spikeinterface/core/npysnippetsextractor.py
@@ -27,6 +27,9 @@ def __init__(
         num_segments = len(file_paths)
         data = np.load(file_paths[0], mmap_mode="r")

+        if channel_ids is None:
+            channel_ids = np.arange(data["snippet"].shape[2])
+
         BaseSnippets.__init__(
             self,
             sampling_frequency,
@@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None):
         arr = np.empty(n, dtype=snippets_t, order="F")
         arr["frame"] = snippets.get_frames(segment_index=i)
         arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False)
-
+        file_paths[i].parent.mkdir(parents=True, exist_ok=True)
         np.save(file_paths[i], arr)


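The on-disk format is a NumPy structured array with a `frame` field and a 3-D `snippet` field; the new default derives `channel_ids` from axis 2 of the snippets. A round-trip sketch (the field dtypes are assumptions based on the code shown here):

```python
import numpy as np

num_snippets, snippet_len, num_channels = 8, 30, 4
snippets_t = np.dtype(
    [("frame", np.int64), ("snippet", np.float32, (snippet_len, num_channels))]
)

arr = np.empty(num_snippets, dtype=snippets_t, order="F")
arr["frame"] = np.arange(num_snippets) * 100
arr["snippet"] = np.random.randn(num_snippets, snippet_len, num_channels)
np.save("segment0.npy", arr)

data = np.load("segment0.npy", mmap_mode="r")
channel_ids = np.arange(data["snippet"].shape[2])  # the new default
print(channel_ids)  # [0 1 2 3]
```
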
6 changes: 5 additions & 1 deletion src/spikeinterface/core/sparsity.py
@@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids):

         self.num_channels = self.channel_ids.size
         self.num_units = self.unit_ids.size
-        self.max_num_active_channels = self.mask.sum(axis=1).max()
+        if self.mask.shape[0]:
+            self.max_num_active_channels = self.mask.sum(axis=1).max()
+        else:
+            # empty sorting without units
+            self.max_num_active_channels = 0

     def __repr__(self):
         density = np.mean(self.mask)
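
The guard matters because reducing an empty axis with `.max()` raises. A quick NumPy demonstration of the failure mode and the fix:

```python
import numpy as np

mask = np.zeros((0, 32), dtype=bool)  # empty sorting: 0 units, 32 channels

# mask.sum(axis=1).max() would raise:
#   ValueError: zero-size array to reduction operation maximum which has no identity
if mask.shape[0]:
    max_num_active_channels = mask.sum(axis=1).max()
else:
    max_num_active_channels = 0  # empty sorting without units

print(max_num_active_channels)  # 0
```
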
48 changes: 28 additions & 20 deletions src/spikeinterface/core/template_tools.py
@@ -1,21 +1,24 @@
+from __future__ import annotations
 import numpy as np
 import warnings

 from .sparsity import compute_sparsity, _sparsity_doc
 from .recording_tools import get_channel_distances, get_noise_levels


-def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"):
+def get_template_amplitudes(
+    waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"
+):
     """
     Get amplitude per channel for each unit.

     Parameters
     ----------
     waveform_extractor: WaveformExtractor
         The waveform extractor
-    peak_sign: str
-        Sign of the template to compute best channels ('neg', 'pos', 'both')
-    mode: str
+    peak_sign: "neg" | "pos" | "both", default: "neg"
+        Sign of the template to compute best channels
+    mode: "extremum" | "at_index", default: "extremum"
         'extremum': max or min
         'at_index': take value at spike index
@@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st
     peak_values: dict
         Dictionary with unit ids as keys and template amplitudes as values
     """
-    assert peak_sign in ("both", "neg", "pos")
-    assert mode in ("extremum", "at_index")
+    assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'"
+    assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
     unit_ids = waveform_extractor.sorting.unit_ids

     before = waveform_extractor.nbefore
@@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st


 def get_template_extremum_channel(
-    waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id"
+    waveform_extractor,
+    peak_sign: "neg" | "pos" | "both" = "neg",
+    mode: "extremum" | "at_index" = "extremum",
+    outputs: "id" | "index" = "id",
 ):
     """
     Compute the channel with the extremum peak for each unit.
@@ -66,12 +72,12 @@ def get_template_extremum_channel(
     ----------
     waveform_extractor: WaveformExtractor
         The waveform extractor
-    peak_sign: str
-        Sign of the template to compute best channels ('neg', 'pos', 'both')
-    mode: str
+    peak_sign: "neg" | "pos" | "both", default: "neg"
+        Sign of the template to compute best channels
+    mode: "extremum" | "at_index", default: "extremum"
         'extremum': max or min
         'at_index': take value at spike index
-    outputs: str
+    outputs: "id" | "index", default: "id"
         * 'id': channel id
         * 'index': channel index
@@ -159,7 +165,7 @@ def get_template_channel_sparsity(
 get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc)


-def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"):
+def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"):
     """
     In some situations spike sorters could return a spike index with a small shift related to the waveform peak.
     This function estimates and return these alignment shifts for the mean template.
@@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
     ----------
     waveform_extractor: WaveformExtractor
         The waveform extractor
-    peak_sign: str
-        Sign of the template to compute best channels ('neg', 'pos', 'both')
+    peak_sign: "neg" | "pos" | "both", default: "neg"
+        Sign of the template to compute best channels

     Returns
     -------
@@ -203,17 +209,19 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
     return shifts


-def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"):
+def get_template_extremum_amplitude(
+    waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"
+):
     """
     Computes amplitudes on the best channel.

     Parameters
     ----------
     waveform_extractor: WaveformExtractor
         The waveform extractor
-    peak_sign: str
-        Sign of the template to compute best channels ('neg', 'pos', 'both')
-    mode: str
+    peak_sign: "neg" | "pos" | "both"
+        Sign of the template to compute best channels
+    mode: "extremum" | "at_index", default: "at_index"
         Where the amplitude is computed
         'extremum': max or min
         'at_index': take value at spike index
@@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg",
     amplitudes: dict
         Dictionary with unit ids as keys and amplitudes as values
     """
-    assert peak_sign in ("both", "neg", "pos")
-    assert mode in ("extremum", "at_index")
+    assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'"
+    assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
     unit_ids = waveform_extractor.sorting.unit_ids

     before = waveform_extractor.nbefore
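
To make `peak_sign` / `mode` concrete, here is a rough NumPy sketch of the per-channel amplitude semantics on a single toy template. It is a simplification under assumed conventions, not the library's exact implementation:

```python
import numpy as np

template = np.array(
    [[0.0, 1.0], [-5.0, 3.0], [-2.0, 4.0]]
)  # toy (num_samples, num_channels) template
nbefore = 1  # spike index within the template window


def channel_amplitudes(template, nbefore, peak_sign="neg", mode="extremum"):
    if peak_sign == "neg":
        values = template.min(axis=0) if mode == "extremum" else template[nbefore, :]
    elif peak_sign == "pos":
        values = template.max(axis=0) if mode == "extremum" else template[nbefore, :]
    else:  # "both": use the absolute template
        abs_t = np.abs(template)
        values = abs_t.max(axis=0) if mode == "extremum" else abs_t[nbefore, :]
    return values


amps = channel_amplitudes(template, nbefore, peak_sign="neg", mode="extremum")
print(amps)             # [-5.  1.] per-channel amplitudes
print(np.argmin(amps))  # 0: the extremum channel for peak_sign="neg"
```
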
1 change: 1 addition & 0 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -559,3 +559,4 @@ def test_non_json_object():
     test_recordingless()
     # test_compute_sparsity()
     # test_non_json_object()
+    test_empty_sorting()
2 changes: 1 addition & 1 deletion src/spikeinterface/core/unitsaggregationsorting.py
@@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
                 try:
                     property_dict[prop_name] = np.concatenate((property_dict[prop_name], values))
                 except Exception as e:
-                    print(f"Skipping property '{prop_name}' for shape inconsistency")
+                    print(f"Skipping property '{prop_name}' due to shape inconsistency")
                     del property_dict[prop_name]
                     break
         for prop_name, prop_values in property_dict.items():
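
The try/except exists because `np.concatenate` refuses arrays whose shapes disagree beyond the concatenation axis, which happens when two aggregated sortings carry the same property with different shapes. A small demonstration:

```python
import numpy as np

a = np.zeros((3, 2))  # property values from one sorting: 3 units x 2 columns
b = np.zeros((4, 5))  # same property from another sorting, but 5 columns

try:
    merged = np.concatenate((a, b))
except Exception as e:
    # ValueError: all the input array dimensions except for the
    # concatenation axis must match exactly ...
    print(f"Skipping property due to shape inconsistency: {e}")
```
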
9 changes: 5 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
@@ -1457,13 +1457,13 @@ def extract_waveforms(
     folder=None,
     mode="folder",
     precompute_template=("average",),
-    ms_before=3.0,
-    ms_after=4.0,
+    ms_before=1.0,
+    ms_after=2.0,
     max_spikes_per_unit=500,
     overwrite=False,
     return_scaled=True,
     dtype=None,
-    sparse=False,
+    sparse=True,
     sparsity=None,
     num_spikes_for_sparsity=100,
     allow_unfiltered=False,
@@ -1507,7 +1507,7 @@ def extract_waveforms(
         If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV.
     dtype: dtype or None
         Dtype of the output waveforms. If None, the recording dtype is maintained.
-    sparse: bool (default False)
+    sparse: bool, default: True
         If True, before extracting all waveforms the `precompute_sparsity()` function is run using
         a few spikes to get an estimate of dense templates to create a ChannelSparsity object.
         Then, the waveforms will be sparse at extraction time, which saves a lot of memory.
@@ -1726,6 +1726,7 @@ def precompute_sparsity(
         max_spikes_per_unit=num_spikes_for_sparsity,
         return_scaled=False,
         allow_unfiltered=allow_unfiltered,
+        sparse=False,
         **job_kwargs,
     )
     local_sparsity = compute_sparsity(local_we, **sparse_kwargs)
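
With these new defaults, `extract_waveforms` returns sparse waveforms over a 1 ms / 2 ms window unless told otherwise. A hedged usage sketch with the toy generators (the `allow_unfiltered` flag is added defensively in case the generated recording is not flagged as filtered):

```python
from spikeinterface.core import extract_waveforms, generate_recording, generate_sorting

recording = generate_recording(num_channels=4, durations=[10.0])
sorting = generate_sorting(num_units=3, durations=[10.0])

# New defaults from this commit: ms_before=1.0, ms_after=2.0, sparse=True
we = extract_waveforms(recording, sorting, mode="memory", allow_unfiltered=True)

# Opting back into the previous dense behavior:
we_dense = extract_waveforms(
    recording,
    sorting,
    mode="memory",
    ms_before=3.0,
    ms_after=4.0,
    sparse=False,
    allow_unfiltered=True,
)
```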