Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve assert messages (preprocessing & core) #2078

Merged
merged 3 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, main_ids: Sequence) -> None:
self._kwargs = {}

# 'main_ids' will either be channel_ids or units_ids
# They is used for properties
# They are used for properties
self._main_ids = np.array(main_ids)

# dict at object level
Expand Down Expand Up @@ -984,7 +984,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
class_name = None

if "kwargs" not in dic:
raise Exception(f"This dict cannot be load into extractor {dic}")
raise Exception(f"This dict cannot be loaded into extractor {dic}")

# Create new kwargs to avoid modifying the original dict["kwargs"]
new_kwargs = dict()
Expand All @@ -1005,7 +1005,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class"
if not _check_same_version(class_name, dic["version"]):
warnings.warn(
f"Versions are not the same. This might lead compatibility errors. "
f"Versions are not the same. This might lead to compatibility errors. "
f"Using {class_name.split('.')[0]}=={dic['version']} is recommended"
)

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def get_traces(

if not self.has_scaled():
raise ValueError(
"This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
"This recording does not support return_scaled=True (need gain_to_uV and offset_"
"to_uV properties)"
)
else:
gains = self.get_property("gain_to_uV")
Expand Down Expand Up @@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True):
if with_warning:
warn(
"Setting times with Recording.set_times() is not recommended because "
"times are not always propagated to across preprocessing"
"Use use this carefully!"
"times are not always propagated across preprocessing"
"Use this carefully!"
)

def sample_index_to_time(self, sample_ind, segment_index=None):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True):
if check_spike_frames:
if has_exceeding_spikes(recording, self):
warnings.warn(
"Some spikes are exceeding the recording's duration! "
"Some spikes exceed the recording's duration! "
"Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
"Might be necessary for further postprocessing."
)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
file_path_list = [Path(file_paths)]

if t_starts is not None:
assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths"
assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths"
t_starts = [float(t_start) for t_start in t_starts]

dtype = np.dtype(dtype)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments):
times_kargs0 = parent_segment0.get_times_kwargs()
if times_kargs0["time_vector"] is None:
for ps in parent_segments:
assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set"
assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set"
else:
for ps in parent_segments:
assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], (
"All segment should have the same " "t_start"
"All segments should have the same " "t_start"
)

BaseRecordingSegment.__init__(self, **times_kargs0)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
), "ChannelSliceRecording: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceRecording : channel_ids not unique"
), "ChannelSliceRecording : channel_ids are not unique"

sampling_frequency = parent_recording.get_sampling_frequency()

Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None):
), "ChannelSliceSnippets: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceSnippets : channel_ids not unique"
), "ChannelSliceSnippets : channel_ids are not unique"

sampling_frequency = parent_snippets.get_sampling_frequency()

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording):
def __init__(self, parent_recording, start_frame=None, end_frame=None):
channel_ids = parent_recording.get_channel_ids()

assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment"
assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment"

parent_size = parent_recording.get_num_samples(0)
if start_frame is None:
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting):
def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True):
unit_ids = parent_sorting.get_unit_ids()

assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment"
assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment"

if start_frame is None:
start_frame = 0
Expand All @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
end_frame = parent_n_samples
assert (
end_frame <= parent_n_samples
), "`end_frame` should be smaller than the sortings total number of samples."
), "`end_frame` should be smaller than the sortings' total number of samples."
assert (
start_frame <= parent_n_samples
), "`start_frame` should be smaller than the sortings total number of samples."
), "`start_frame` should be smaller than the sortings' total number of samples."
if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
raise ValueError(
"The sorting object has spikes exceeding the recording duration. You have to remove those spikes "
Expand All @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
end_frame = max_spike_time + 1

assert start_frame < end_frame, (
"`start_frame` should be greater than `end_frame`. "
"`start_frame` should be less than `end_frame`. "
"This may be due to start_frame >= max_spike_time, if the end frame "
"was not specified explicitly."
)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,11 +1101,11 @@ def __init__(
# handle also upsampling and jitter
upsample_factor = templates.shape[3]
elif templates.ndim == 5:
# handle also dirft
# handle also drift
raise NotImplementedError("Drift will be implented soon...")
# upsample_factor = templates.shape[3]
else:
raise ValueError("templates have wring dim should 3 or 4")
raise ValueError("templates have wrong dim should 3 or 4")

if upsample_factor is not None:
assert upsample_vector is not None
Expand Down
48 changes: 28 additions & 20 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from __future__ import annotations
import numpy as np
import warnings

from .sparsity import compute_sparsity, _sparsity_doc
from .recording_tools import get_channel_distances, get_noise_levels


def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"):
def get_template_amplitudes(
waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"
):
"""
Get amplitude per channel for each unit.

Parameters
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
'extremum': max or min
'at_index': take value at spike index

Expand All @@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st
peak_values: dict
Dictionary with unit ids as keys and template amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos")
assert mode in ("extremum", "at_index")
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
unit_ids = waveform_extractor.sorting.unit_ids

before = waveform_extractor.nbefore
Expand Down Expand Up @@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st


def get_template_extremum_channel(
waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id"
waveform_extractor,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" = "extremum",
outputs: "id" | "index" = "id",
):
"""
Compute the channel with the extremum peak for each unit.
Expand All @@ -66,12 +72,12 @@ def get_template_extremum_channel(
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
'extremum': max or min
'at_index': take value at spike index
outputs: str
outputs: "id" | "index", default: "id"
* 'id': channel id
* 'index': channel index

Expand Down Expand Up @@ -159,7 +165,7 @@ def get_template_channel_sparsity(
get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc)


def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"):
def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"):
"""
In some situations spike sorters could return a spike index with a small shift related to the waveform peak.
This function estimates and return these alignment shifts for the mean template.
Expand All @@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels

Returns
-------
Expand Down Expand Up @@ -203,17 +209,19 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
return shifts


def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"):
def get_template_extremum_amplitude(
waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"
):
"""
Computes amplitudes on the best channel.

Parameters
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "at_index"
Where the amplitude is computed
'extremum': max or min
'at_index': take value at spike index
Expand All @@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg",
amplitudes: dict
Dictionary with unit ids as keys and amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos")
assert mode in ("extremum", "at_index")
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
unit_ids = waveform_extractor.sorting.unit_ids

before = waveform_extractor.nbefore
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
try:
property_dict[prop_name] = np.concatenate((property_dict[prop_name], values))
except Exception as e:
print(f"Skipping property '{prop_name}' for shape inconsistency")
print(f"Skipping property '{prop_name}' due to shape inconsistency")
del property_dict[prop_name]
break
for prop_name, prop_values in property_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
chunk_size=500,
seed=0,
):
assert direction in ("upper", "lower", "both")
assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'"

if fill_value is None or quantile_threshold is not None:
random_data = get_random_data_chunks(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
ref_channel_ids = np.asarray(ref_channel_ids)
assert np.all(
[ch in recording.get_channel_ids() for ch in ref_channel_ids]
), "Some wrong 'ref_channel_ids'!"
), "Some 'ref_channel_ids' are wrong!"
elif reference == "local":
assert groups is None, "With 'local' CAR, the group option should not be used."
closest_inds, dist = get_closest_channels(recording)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def detect_bad_channels(

if bad_channel_ids.size > recording.get_num_channels() / 3:
warnings.warn(
"Over 1/3 of channels are detected as bad. In the precense of a high"
"Over 1/3 of channels are detected as bad. In the presence of a high"
"number of dead / noisy channels, bad channel detection may fail "
"(erroneously label good channels as dead)."
"(good channels may be erroneously labeled as dead)."
)

elif method == "neighborhood_r2":
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def __init__(
):
import scipy.signal

assert filter_mode in ("sos", "ba")
assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'"
fs = recording.get_sampling_frequency()
if coeff is None:
assert btype in ("bandpass", "highpass")
assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'"
# coefficient
# self.coeff is 'sos' or 'ab' style
filter_coeff = scipy.signal.iirfilter(
Expand Down Expand Up @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
if dtype.kind == "u":
raise TypeError(
"The notch filter only supports signed types. Use the 'dtype' argument"
"to specify a signed type (e.g. 'int16', 'float32'"
"to specify a signed type (e.g. 'int16', 'float32')"
)

BasePreprocessor.__init__(self, recording, dtype=dtype)
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/preprocessing/filter_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def __init__(
margin_ms=5.0,
):
assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)"

assert btype in ("bandpass", "lowpass", "highpass", "bandstop")
assert filter_mode in ("sos",)
btype_modes = ("bandpass", "lowpass", "highpass", "bandstop")
assert btype in btype_modes, f"'btype' must be in {btype_modes}"
assert filter_mode in ("sos",), "'filter_mode' must be 'sos'"

# coefficient
sf = recording.get_sampling_frequency()
Expand Down Expand Up @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin):
self.margin = margin

def get_traces(self, start_frame, end_frame, channel_indices):
assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size"
assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size"
assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"
assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"

chunk_size = end_frame - start_frame
if chunk_size != self.executor.chunk_size:
Expand Down Expand Up @@ -157,7 +157,7 @@ def process(self, traces):

if traces.shape[0] != self.full_size:
if self.full_size is not None:
print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!")
print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!")
self.create_buffers_and_compile()

event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces * self.taper[np.newaxis, :]

# apply actual HP filter
import scipy
import scipy.signal

traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1)

Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/normalize_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
dtype="float32",
**random_chunk_kwargs,
):
assert mode in ("pool_channel", "by_channel")
assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'"

random_data = get_random_data_chunks(recording, **random_chunk_kwargs)

Expand Down Expand Up @@ -260,7 +260,7 @@ def __init__(
dtype="float32",
**random_chunk_kwargs,
):
assert mode in ("median+mad", "mean+std")
assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'"

# fix dtype
dtype_ = fix_dtype(recording, dtype)
Expand Down
Loading