Merge pull request #2109 from zm711/main
Improve assert messaging (exporters, curation, part of sorters)
alejoe91 authored Oct 17, 2023
2 parents 4c76371 + 82ee456 commit 2e6f7ce
Showing 8 changed files with 47 additions and 30 deletions.
22 changes: 17 additions & 5 deletions src/spikeinterface/curation/curation_tools.py
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Optional
import numpy as np

@@ -9,9 +10,15 @@
except ModuleNotFoundError as err:
HAVE_NUMBA = False

_methods = ("keep_first", "random", "keep_last", "keep_first_iterative", "keep_last_iterative")
_methods_numpy = ("keep_first", "random", "keep_last")


def _find_duplicated_spikes_numpy(
spike_train: np.ndarray, censored_period: int, seed: Optional[int] = None, method: str = "keep_first"
spike_train: np.ndarray,
censored_period: int,
seed: Optional[int] = None,
method: "keep_first" | "random" | "keep_last" = "keep_first",
) -> np.ndarray:
(indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period)

@@ -29,7 +36,9 @@ def _find_duplicated_spikes_numpy(

(indices_of_duplicates,) = np.where(~mask)
elif method != "keep_last":
raise ValueError(f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy.")
raise ValueError(
f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy. Use one of {_methods_numpy}."
)

return indices_of_duplicates

@@ -84,7 +93,10 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period):


def find_duplicated_spikes(
spike_train, censored_period: int, method: str = "random", seed: Optional[int] = None
spike_train,
censored_period: int,
method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random",
seed: Optional[int] = None,
) -> np.ndarray:
"""
Finds the indices where spikes should be considered duplicates.
@@ -97,7 +109,7 @@ def find_duplicated_spikes(
The spike train on which to look for duplicated spikes.
censored_period: int
The censored period for duplicates (in sample time).
method: str in ("keep_first", "keep_last", "keep_first_iterative', 'keep_last_iterative", random")
method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random"
Method used to remove the duplicated spikes.
seed: int | None
The seed to use if method="random".
@@ -120,4 +132,4 @@ def find_duplicated_spikes(
assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba"
return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period)
else:
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes.")
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}.")
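
For context, a minimal sketch of how the sharpened message surfaces, assuming a direct import from curation_tools (the spike train values and the 'keep_all' typo are made up for illustration):

import numpy as np
from spikeinterface.curation.curation_tools import find_duplicated_spikes

# toy spike train in sample indices (hypothetical values)
spike_train = np.array([100, 102, 500, 501, 900])

# spikes closer than censored_period samples are flagged as duplicates
duplicates = find_duplicated_spikes(spike_train, censored_period=5, method="keep_first")

# an unsupported method now reports the valid options, e.g.:
# ValueError: Method 'keep_all' isn't a valid method for find_duplicated_spikes. Use one of ('keep_first', 'random', 'keep_last', 'keep_first_iterative', 'keep_last_iterative')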
12 changes: 6 additions & 6 deletions src/spikeinterface/curation/curationsorting.py
Original file line number Diff line number Diff line change
@@ -148,24 +148,24 @@ def remove_empty_units(self):
edges = None
self._add_new_stage(new_sorting, edges)

def redo_avaiable(self):
def redo_available(self):
# useful function for a gui
return self._sorting_stages_i < len(self._sorting_stages)

def undo_avaiable(self):
def undo_available(self):
# useful function for a gui
return self._sorting_stages_i > 0

def undo(self):
if self.undo_avaiable():
if self.undo_available():
self._sorting_stages_i -= 1

def redo(self):
if self.redo_avaiable():
if self.redo_available():
self._sorting_stages_i += 1

def draw_graph(self, **kwargs):
assert self._make_graph, "to make a graph make_graph=True"
assert self._make_graph, "to make a graph use make_graph=True"
graph = self.graph
ids = [c.unit_id for c in graph.nodes]
pos = {n: (n.stage_id, -ids.index(n.unit_id)) for n in graph.nodes}
@@ -174,7 +174,7 @@ def draw_graph(self, **kwargs):

@property
def graph(self):
assert self._make_graph, "to have a graph make_graph=True"
assert self._make_graph, "to have a graph use make_graph=True"
return self._graphs[self._sorting_stages_i]

@property
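
Note that this fixes a misspelling in public method names (redo_avaiable/undo_avaiable become redo_available/undo_available), so any GUI code calling the old names needs updating. A minimal usage sketch, where `sorting` is a hypothetical BaseSorting object:

from spikeinterface.curation import CurationSorting

cs = CurationSorting(sorting, make_graph=True)
# ... apply merge/split/remove curation steps here ...
if cs.undo_available():   # previously undo_avaiable()
    cs.undo()
if cs.redo_available():   # previously redo_avaiable()
    cs.redo()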
12 changes: 8 additions & 4 deletions src/spikeinterface/curation/remove_redundant.py
@@ -1,3 +1,4 @@
from __future__ import annotations
import numpy as np

from spikeinterface import WaveformExtractor
@@ -6,6 +7,9 @@
from ..postprocessing import align_sorting


_remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes")


def remove_redundant_units(
sorting_or_waveform_extractor,
align=True,
@@ -42,15 +46,15 @@ def remove_redundant_units(
duplicate_threshold : float, optional
Final threshold on the portion of coincident events over the number of spikes above which the
unit is removed, by default 0.8
remove_strategy: str
remove_strategy: 'minimum_shift' | 'highest_amplitude' | 'max_spikes', default: 'minimum_shift'
Which strategy to remove one of the two duplicated units:
* 'minimum_shift': keep the unit with best peak alignment (minimum shift)
If shifts are equal then the 'highest_amplitude' is used
* 'highest_amplitude': keep the unit with the best amplitude on unshifted max.
* 'max_spikes': keep the unit with more spikes
peak_sign: str ('neg', 'pos', 'both')
peak_sign: 'neg' | 'pos' | 'both', default: 'neg'
Used when remove_strategy='highest_amplitude'
extra_outputs: bool
If True, will return the redundant pairs.
@@ -93,7 +97,7 @@ def remove_redundant_units(
peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()}

if remove_strategy == "minimum_shift":
assert align, "remove_strategy with minimum_shift need align=True"
assert align, "remove_strategy with minimum_shift needs align=True"
for u1, u2 in redundant_unit_pairs:
if np.abs(unit_peak_shifts[u1]) > np.abs(unit_peak_shifts[u2]):
remove_unit_ids.append(u1)
@@ -125,7 +129,7 @@ def remove_redundant_units(
# this will be implemented in a future PR by the first who needs it!
raise NotImplementedError()
else:
raise ValueError(f"remove_strategy : {remove_strategy} is not implemented!")
raise ValueError(f"remove_strategy : {remove_strategy} is not implemented! Options are {_remove_strategies}")

sorting_clean = sorting.remove_units(remove_unit_ids)

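
A sketch of the strategy selection these messages guard, assuming `we` is a hypothetical WaveformExtractor (the misspelled strategy is made up to show the new message):

from spikeinterface.curation import remove_redundant_units

clean_sorting = remove_redundant_units(we, remove_strategy="minimum_shift", align=True)

# an unknown strategy now lists the valid options, e.g.:
# ValueError: remove_strategy : max_amplitude is not implemented! Options are ('minimum_shift', 'highest_amplitude', 'max_spikes')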
11 changes: 5 additions & 6 deletions src/spikeinterface/curation/splitunitsorting.py
@@ -21,11 +21,10 @@ class SplitUnitSorting(BaseSorting):
be the same length as the spike train (for each segment)
new_unit_ids: int
Unit ids of the new units to be created.
properties_policy: str
properties_policy: 'keep' | 'remove', default: 'keep'
Policy used to propagate properties. If 'keep' the properties will be passed to the new units
(if the units_to_merge have the same value). If 'remove' the new units will have an empty
value for all the properties of the new unit.
Default: 'keep'
Returns
-------
sorting: Sorting
@@ -48,19 +47,19 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype)
else:
new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype)
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids should be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list have more ids indices than the length of new_unit_ids"
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list has more id indices than the length of new_unit_ids"

assert parent_sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
assert split_unit_id in parents_unit_ids, "Unit to split should be in parent sorting"
assert split_unit_id in parents_unit_ids, "Unit to split must be in parent sorting"
assert properties_policy == "keep" or properties_policy == "remove", (
"properties_policy must be " "keep" " or " "remove" ""
)
assert not any(
np.isin(new_unit_ids, unchanged_units)
), "new_unit_ids should be new units or one could be equal to split_unit_id"
), "new_unit_ids must be new unit ids; at most one of them can be equal to split_unit_id"

sampling_frequency = parent_sorting.get_sampling_frequency()
units_ids = np.concatenate([unchanged_units, new_unit_ids])
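
A minimal sketch of the constructor these asserts guard, assuming a hypothetical `sorting` whose unit 3 is split into two new units (the label array is illustrative; one array per segment, same length as that unit's spike train):

import numpy as np
from spikeinterface.curation import SplitUnitSorting

labels = np.array([0, 1, 0, 0, 1])  # assigns each spike of unit 3 to split 0 or 1
split_sorting = SplitUnitSorting(
    sorting, split_unit_id=3, indices_list=[labels], new_unit_ids=[10, 11], properties_policy="keep"
)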
4 changes: 3 additions & 1 deletion src/spikeinterface/exporters/to_phy.py
@@ -78,7 +78,9 @@ def export_to_phy(
), "waveform_extractor must be a WaveformExtractor object"
sorting = waveform_extractor.sorting

assert waveform_extractor.get_num_segments() == 1, "Export to phy only works with one segment"
assert (
waveform_extractor.get_num_segments() == 1
), f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments"
num_chans = waveform_extractor.get_num_channels()
fs = waveform_extractor.sampling_frequency

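
A sketch of the call this assert protects, assuming `we` is a hypothetical WaveformExtractor:

from spikeinterface.exporters import export_to_phy

export_to_phy(we, output_folder="phy_output")   # works with a single segment

# a multi-segment extractor now fails with the actual count, e.g.:
# AssertionError: Export to phy only works with one segment, your extractor has 2 segments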
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/basesorter.py
@@ -103,7 +103,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
)

if not isinstance(recording, BaseRecordingSnippets):
raise ValueError("recording must be a Recording or Snippets!!")
raise ValueError("recording must be a Recording or a Snippets!!")

if cls.requires_locations:
locations = recording.get_channel_locations()
@@ -133,7 +133,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
if recording.get_num_segments() > 1:
if not cls.handle_multi_segment:
raise ValueError(
f"This sorter {cls.sorter_name} do not handle multi segment, use si.concatenate_recordings(...)"
f"This sorter {cls.sorter_name} does not handle multi-segment recordings, use si.concatenate_recordings(...)"
)

rec_file = output_folder / "spikeinterface_recording.json"
@@ -299,7 +299,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
# check errors in log file
log_file = output_folder / "spikeinterface_log.json"
if not log_file.is_file():
raise SpikeSortingError("get result error: the folder does not contain the `spikeinterface_log.json` file")
raise SpikeSortingError("Get result error: the folder does not contain the `spikeinterface_log.json` file")

with log_file.open("r", encoding="utf8") as f:
log = json.load(f)
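
For the multi-segment case the error itself suggests the fix; a sketch with hypothetical single-segment recordings `rec_a` and `rec_b`:

import spikeinterface as si
from spikeinterface.sorters import run_sorter

recording = si.concatenate_recordings([rec_a, rec_b])
sorting = run_sorter("tridesclous", recording, output_folder="tdc_output")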
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/launcher.py
@@ -374,7 +374,7 @@ def run_sorters(
mode_if_folder_exists in ("raise", "keep", "overwrite")

if mode_if_folder_exists == "raise" and working_folder.is_dir():
raise Exception("working_folder already exists, please remove it")
raise Exception(f"working_folder {working_folder} already exists, please remove it")

assert engine in _implemented_engine, f"engine must be in {_implemented_engine}"

@@ -390,7 +390,7 @@ def run_sorters(
elif isinstance(recording_dict_or_list, dict):
recording_dict = recording_dict_or_list
else:
raise ValueError("bad recording dict")
raise ValueError("Wrong format for recording_dict_or_list")

dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0]))
assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!"
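
A sketch of the call these checks protect, with hypothetical recordings `rec0` and `rec1` keyed by string names:

from spikeinterface.sorters import run_sorters

recording_dict = {"rec0": rec0, "rec1": rec1}
results = run_sorters(
    ["tridesclous"], recording_dict, working_folder="sorter_runs", mode_if_folder_exists="overwrite"
)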
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/sorterlist.py
@@ -89,7 +89,7 @@ def get_default_sorter_params(sorter_name_or_class):
elif sorter_name_or_class in sorter_full_list:
SorterClass = sorter_name_or_class
else:
raise (ValueError("Unknown sorter"))
raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given"))

return SorterClass.default_params()

@@ -113,7 +113,7 @@ def get_sorter_params_description(sorter_name_or_class):
elif sorter_name_or_class in sorter_full_list:
SorterClass = sorter_name_or_class
else:
raise (ValueError("Unknown sorter"))
raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given"))

return SorterClass.params_description()

@@ -137,6 +137,6 @@ def get_sorter_description(sorter_name_or_class):
elif sorter_name_or_class in sorter_full_list:
SorterClass = sorter_name_or_class
else:
raise (ValueError("Unknown sorter"))
raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given"))

return SorterClass.sorter_description
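
A sketch of the lookup these three functions share (the misspelled sorter name is made up to show the new message):

from spikeinterface.sorters import get_default_sorter_params

params = get_default_sorter_params("tridesclous")

# a wrong name is now echoed back, e.g.:
# ValueError: Unknown sorter tridesclouz has been given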
