From 2307c9b5a302c6532270a54677af0f2f139f2f49 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 09:42:30 +0100 Subject: [PATCH 1/3] Comparison, Generation, Postprocessing, QualityMetrics, SortingComponents docstrings compliance --- src/spikeinterface/comparison/collision.py | 14 ++- src/spikeinterface/comparison/correlogram.py | 15 +++ .../comparison/groundtruthstudy.py | 6 +- src/spikeinterface/comparison/hybrid.py | 2 + .../comparison/multicomparisons.py | 10 ++ .../comparison/paircomparisons.py | 54 +++++++- .../core/analyzer_extension_core.py | 6 +- src/spikeinterface/core/generate.py | 118 ++++++++++-------- src/spikeinterface/curation/auto_merge.py | 3 + .../curation/curation_format.py | 16 +-- .../curation/curationsorting.py | 3 +- .../curation/mergeunitssorting.py | 2 +- .../curation/remove_redundant.py | 4 + .../curation/splitunitsorting.py | 5 +- .../extractors/neoextractors/maxwell.py | 2 +- .../extractors/neoextractors/plexon.py | 2 +- src/spikeinterface/generation/drift_tools.py | 12 ++ .../generation/drifting_generator.py | 2 + src/spikeinterface/generation/hybrid_tools.py | 8 +- .../generation/template_database.py | 2 + .../postprocessing/correlograms.py | 2 +- src/spikeinterface/postprocessing/isi.py | 2 +- .../postprocessing/principal_component.py | 12 +- .../postprocessing/spike_amplitudes.py | 4 +- .../postprocessing/spike_locations.py | 4 +- .../postprocessing/template_metrics.py | 8 +- .../postprocessing/template_similarity.py | 11 +- .../postprocessing/unit_locations.py | 8 +- src/spikeinterface/preprocessing/motion.py | 5 + .../qualitymetrics/misc_metrics.py | 11 +- src/spikeinterface/sorters/launcher.py | 2 + src/spikeinterface/sorters/runsorter.py | 21 +--- .../sortingcomponents/clustering/main.py | 10 +- .../sortingcomponents/matching/main.py | 14 ++- .../motion/motion_cleaner.py | 10 +- .../motion/motion_estimation.py | 16 ++- .../motion/motion_interpolation.py | 37 +++--- 
.../sortingcomponents/motion/motion_utils.py | 1 + .../sortingcomponents/peak_detection.py | 10 +- .../sortingcomponents/peak_localization.py | 8 +- src/spikeinterface/sortingcomponents/tools.py | 6 +- 41 files changed, 312 insertions(+), 176 deletions(-) diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 9b455e6200..574bd16093 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -13,9 +13,19 @@ class CollisionGTComparison(GroundTruthComparison): This class needs maintenance and need a bit of refactoring. - - collision_lag : float + Parameters + ---------- + gt_sorting : SortingExtractor + The first sorting for the comparison + collision_lag : float, default 2.0 Collision lag in ms. + tested_sorting : SortingExtractor + The second sorting for the comparison + nbins : int, default : 11 + Number of collision bins + **kwargs : dict + Keyword arguments for `GroundTruthComparison` + """ diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 5a0dd1d3a7..0cafef2c12 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -15,6 +15,21 @@ class CorrelogramGTComparison(GroundTruthComparison): This class needs maintenance and need a bit of refactoring. + Parameters + ---------- + gt_sorting : SortingExtractor + The first sorting for the comparison + tested_sorting : SortingExtractor + The second sorting for the comparison + bin_ms : float, default: 1.0 + Size of bin for correlograms + window_ms : float, default: 100.0 + The window around the spike to compute the correlation in ms. 
+ well_detected_score : float, default: 0.8 + Agreement score above which units are well detected + **kwargs : dict + Keyword arguments for `GroundTruthComparison` + """ def __init__(self, gt_sorting, tested_sorting, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index ba7268b4f0..d45956a07e 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -42,6 +42,11 @@ class GroundTruthStudy: This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. Note that the underlying folder structure is not backward compatible! + + Parameters + ---------- + study_folder : str | Path + Path to folder containing `GroundTruthStudy` """ def __init__(self, study_folder): @@ -370,7 +375,6 @@ def get_metrics(self, key): return metrics def get_units_snr(self, key): - """ """ return self.get_metrics(key)["snr"] def get_performance_by_unit(self, case_keys=None): diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index beb9682e37..657fc73b71 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -48,6 +48,8 @@ class HybridUnitsRecording(InjectTemplatesRecording): injected_sorting_folder : str | Path | None If given, the injected sorting is saved to this folder. It must be specified if injected_sorting is None or not serialisable to file.
+ seed : int, default: None + Random seed for amplitude_factor Returns ------- diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 499004e32e..dccff6118d 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -44,6 +44,8 @@ class MultiSortingComparison(BaseMultiComparison, MixinSpikeTrainComparison): best matching two sorters verbose : bool, default: False if True, output is verbose + do_matching : bool, default: True + if True, the matching of units across sorters is performed when the object is created. Returns ------- @@ -319,6 +321,14 @@ class MultiTemplateComparison(BaseMultiComparison, MixinTemplateComparison): Minimum agreement score to for a possible match verbose : bool, default: False if True, output is verbose + do_matching : bool, default: True + if True, the matching of templates across sessions is performed when the object is created + support : "dense" | "union" | "intersection", default: "union" + The support to compute the similarity matrix. + num_shifts : int, default: 0 + Number of shifts to use to shift templates to maximize similarity. + similarity_method : "cosine" | "l1" | "l2", default: "cosine" + Method for the similarity matrix.
Returns ------- diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7d5f04dfdd..9566354918 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -263,7 +263,6 @@ def __init__( gt_name=None, tested_name=None, delta_time=0.4, - sampling_frequency=None, match_score=0.5, well_detected_score=0.8, redundant_score=0.2, @@ -425,6 +424,11 @@ def get_performance(self, method="by_unit", output="pandas"): def print_performance(self, method="pooled_with_average"): """ Print performance with the selected method + + Parameters + ---------- + method : "by_unit" | "pooled_with_average", default: "pooled_with_average" + The method to compute performance """ template_txt_performance = _template_txt_performance @@ -449,6 +453,19 @@ def print_summary(self, well_detected_score=None, redundant_score=None, overmerg * how many gt units (one or several) This summary mix several performance metrics. + + Parameters + ---------- + well_detected_score : float, default: None + The agreement score above which tested units + are counted as "well detected". + redundant_score : float, default: None + The agreement score below which tested units + are counted as "false positive" (and not "redundant"). + overmerged_score : float, default: None + Tested units with 2 or more agreement scores above "overmerged_score" + are counted as "overmerged". + """ txt = _template_summary_part1 @@ -500,6 +517,12 @@ def count_well_detected_units(self, well_detected_score): """ Count how many well detected units. kwargs are the same as get_well_detected_units. + + Parameters + ---------- + well_detected_score : float, default: None + The agreement score above which tested units + are counted as "well detected".
""" return len(self.get_well_detected_units(well_detected_score=well_detected_score)) @@ -540,6 +563,12 @@ def get_false_positive_units(self, redundant_score=None): def count_false_positive_units(self, redundant_score=None): """ See get_false_positive_units(). + + Parameters + ---------- + redundant_score : float, default: None + The agreement score below which tested units + are counted as "false positive" (and not "redundant"). """ return len(self.get_false_positive_units(redundant_score)) @@ -554,7 +583,7 @@ def get_redundant_units(self, redundant_score=None): Parameters ---------- - redundant_score=None : float, default: None + redundant_score : float, default: None The agreement score above which tested units are counted as "redundant" (and not "false positive" ). """ @@ -577,6 +606,12 @@ def get_redundant_units(self, redundant_score=None): def count_redundant_units(self, redundant_score=None): """ See get_redundant_units(). + + Parameters + ---------- + redundant_score : float, default: None + The agreement score above which tested units + are counted as "redundant" (and not "false positive"). """ return len(self.get_redundant_units(redundant_score=redundant_score)) @@ -609,6 +644,12 @@ def get_overmerged_units(self, overmerged_score=None): def count_overmerged_units(self, overmerged_score=None): """ See get_overmerged_units(). + + Parameters + ---------- + overmerged_score : float, default: None + Tested units with 2 or more agreement scores above "overmerged_score" + are counted as "overmerged". """ return len(self.get_overmerged_units(overmerged_score=overmerged_score)) @@ -704,6 +745,10 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): List of units from sorting_analyzer_1 to compare. unit_ids2 : list, default: None List of units from sorting_analyzer_2 to compare. + name1 : str, default: "sess1" + Name of first session. + name2 : str, default: "sess2" + Name of second session.
similarity_method : "cosine" | "l1" | "l2", default: "cosine" Method for the similarity matrix. support : "dense" | "union" | "intersection", default: "union" @@ -712,6 +757,11 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): Number of shifts to use to shift templates to maximize similarity. verbose : bool, default: False If True, output is verbose. + chance_score : float, default: 0.3 + Minimum agreement score to for a possible match + match_score : float, default: 0.7 + Minimum agreement score to match units + Returns ------- diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index ff1dc5dafa..c9cff4fb94 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -675,13 +675,13 @@ class ComputeNoiseLevels(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - **params: dict with additional parameters for the `spikeinterface.get_noise_levels()` function + **params : dict with additional parameters for the `spikeinterface.get_noise_levels()` function Returns ------- - noise_levels: np.array + noise_levels : np.array The noise level vector """ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ff75789aab..187103d031 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -44,10 +44,11 @@ def generate_recording( The number of channels in the recording. sampling_frequency : float, default: 30000. (in Hz) The sampling frequency of the recording, default: 30000. - durations: List[float], default: [5.0, 2.5] - The duration in seconds of each segment in the recording, default: [5.0, 2.5]. - Note that the number of segments is determined by the length of this list. 
- set_probe: bool, default: True + durations : List[float], default: [5.0, 2.5] + The duration in seconds of each segment in the recording. + The number of segments is determined by the length of this list. + set_probe : bool, default: True + If true, attaches probe to the returned `Recording` ndim : int, default: 2 The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe. seed : Optional[int] @@ -621,6 +622,13 @@ def generate_snippets( The number of units. empty_units : list | None, default: None A list of units that will have no spikes. + durations : List[float], default: [10.325, 3.5] + The duration in seconds of each segment in the recording. + The number of segments is determined by the length of this list. + set_probe : bool, default: True + If true, attaches probe to the returned snippets object + **job_kwargs : dict, default: None + Job keyword arguments for `snippets_from_sorting` Returns ------- @@ -799,14 +807,14 @@ def synthesize_random_firings( Sampling rate. duration : float Duration of the segment in seconds. - refractory_period_ms: float + refractory_period_ms : float Refractory period in ms. - firing_rates: float or list[float] + firing_rates : float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. - add_shift_shuffle: bool, default: False + add_shift_shuffle : bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. - seed: int, default: None + seed : int, default: None Seed for the generator. Returns @@ -903,8 +911,10 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No Number of injected units. max_shift : int range of the shift in sample. - ratio: float + ratio : float Proportion of original spike in the injected units. + seed : None|int, default: None + Random seed for creating unit peak shifts. 
Returns ------- @@ -1062,9 +1072,9 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : List[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_levels: float or array, default: 1 + noise_levels : float or array, default: 1 Std of the white noise (if an array, defined by per channels) - cov_matrix: np.array, default None + cov_matrix : np.array, default None The covariance matrix of the noise dtype : Optional[Union[np.dtype, str]], default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. @@ -1076,7 +1086,7 @@ class NoiseGeneratorRecording(BaseRecording): very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) - noise_block_size: int + noise_block_size : int Size in sample of noise block. Notes @@ -1279,10 +1289,14 @@ def generate_recording_by_size( ---------- full_traces_size_GiB : float The size in gigabytes (GiB) of the recording. - num_channels: int + num_channels : int Number of channels. seed : int, default: None The seed for np.random.default_rng. + strategy : "tile_pregenerated"| "on_the_fly", default: "tile_pregenerated" + The strategy of generating noise chunk: + * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and cusume only one noise block. + * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) Returns ------- @@ -1519,25 +1533,25 @@ def generate_templates( Parameters ---------- - channel_locations: np.ndarray + channel_locations : np.ndarray Channel locations. - units_locations: np.ndarray + units_locations : np.ndarray Must be 3D. - sampling_frequency: float + sampling_frequency : float Sampling frequency. 
- ms_before: float + ms_before : float Cut out in ms before spike peak. - ms_after: float + ms_after : float Cut out in ms after spike peak. - seed: int or None + seed : int or None A seed for random. - dtype: numpy.dtype, default: "float32" + dtype : numpy.dtype, default: "float32" Templates dtype - upsample_factor: None or int + upsample_factor : None or int If not None then template are generated upsampled by this factor. Then a new dimention (axis=3) is added to the template with intermediate inter sample representation. This allow easy random jitter by choising a template this new dim - unit_params: dict of arrays or dict of scalar of dict of tuple + unit_params : dict of arrays or dict of scalar of dict of tuple An optional dict containing parameters per units. Keys are parameter names: @@ -1555,6 +1569,10 @@ def generate_templates( * scalar, then an array is created * tuple, then this difine a range for random values. + mode : "sphere" | "ellipsoid", default: "ellipsoid" + Mode for how to calculate distances + + Returns ------- templates: np.array @@ -1674,31 +1692,33 @@ class InjectTemplatesRecording(BaseRecording): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting Sorting object containing all the units and their spike train. - templates: np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] + templates : np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] Array containing the templates to inject for all the units. Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. - nbefore: list[int] | int | None, default: None + nbefore : list[int] | int | None, default: None The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. 
- amplitude_factor: list[float] | float | None, default: None + amplitude_factor : list[float] | float | None, default: None The amplitude of each spike for each unit. Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). Can be a vector with same shape of spike_vector of the sorting. - parent_recording: BaseRecording | None + parent_recording : BaseRecording | None The recording over which to add the templates. If None, will default to traces containing all 0. - num_samples: list[int] | int | None + num_samples : list[int] | int | None The number of samples in the recording per segment. You can use int for mono-segment objects. - upsample_vector: np.array or None, default: None. + upsample_vector : np.array or None, default: None. When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. + check_borders : bool, default: False + Checks if the border of the templates are zero. Returns ------- @@ -2042,55 +2062,55 @@ def generate_ground_truth_recording( Parameters ---------- - durations: list of float, default: [10.] + durations : list of float, default: [10.] Durations in seconds for all segments. - sampling_frequency: float, default: 25000 + sampling_frequency : float, default: 25000 Sampling frequency. - num_channels: int, default: 4 + num_channels : int, default: 4 Number of channels, not used when probe is given. - num_units: int, default: 10 + num_units : int, default: 10 Number of units, not used when sorting is given. - sorting: Sorting or None + sorting : Sorting or None An external sorting object. If not provide, one is genrated. - probe: Probe or None + probe : Probe or None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. - generate_probe_kwargs: dict + generate_probe_kwargs : dict A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. 
- templates: np.array or None + templates : np.array or None The templates of units. If None they are generated. Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. - ms_before: float, default: 1.5 + ms_before : float, default: 1.5 Cut out in ms before spike peak. - ms_after: float, default: 3 + ms_after : float, default: 3 Cut out in ms after spike peak. - upsample_factor: None or int, default: None + upsample_factor : None or int, default: None A upsampling factor used only when templates are not provided. - upsample_vector: np.array or None + upsample_vector : np.array or None Optional the upsample_vector can given. This has the same shape as spike_vector - generate_sorting_kwargs: dict + generate_sorting_kwargs : dict When sorting is not provide, this dict is used to generated a Sorting. - noise_kwargs: dict + noise_kwargs : dict Dict used to generated the noise with NoiseGeneratorRecording. - generate_unit_locations_kwargs: dict + generate_unit_locations_kwargs : dict Dict used to generated template when template not provided. - generate_templates_kwargs: dict + generate_templates_kwargs : dict Dict used to generated template when template not provided. - dtype: np.dtype, default: "float32" + dtype : np.dtype, default: "float32" The dtype of the recording. - seed: int or None + seed : int or None Seed for random initialization. If None a diffrent Recording is generated at every call. Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. Returns ------- - recording: Recording + recording : Recording The generated recording extractor. - sorting: Sorting + sorting : Sorting The generated sorting extractor. 
""" generate_templates_kwargs = generate_templates_kwargs or dict() diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 920d6713ad..1b0f287d09 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -98,6 +98,7 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" + If `preset` is None, you can specify the steps manually with the `steps` parameter. resolve_graph : bool, default: False If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. @@ -145,6 +146,8 @@ def get_potential_auto_merge( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). Returns ------- diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 88190a9bab..20f20b1a2f 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -289,7 +289,7 @@ def apply_curation( The Sorting object to apply merges. curation_dict : dict The curation dict. - censor_ms: float | None, default: None + censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. 
new_id_strategy : "append" | "take_first", default: "append" @@ -297,17 +297,17 @@ def apply_curation( * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges - merging_mode : "soft" | "hard", default: "soft" + merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately performed, reloading waveforms if needed sparsity_overlap : float, default 0.75 - The percentage of overlap that units should share in order to accept merges. If this criteria is not - achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. - - verbose: - - **job_kwargs + The percentage of overlap that units should share in order to accept merges. If this criteria is not + achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. + verbose : bool, default: False + If True, output is verbose + **job_kwargs : dict + Job keyword arguments for `merge_units` Returns ------- diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index 702bb587f7..b4afeab547 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -18,7 +18,7 @@ class CurationSorting: Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object properties_policy : "keep" | "remove", default: "keep" Policy used to propagate properties after split and merge operation. 
If "keep" the properties will be @@ -26,6 +26,7 @@ class CurationSorting: an empty value for all the properties make_graph : bool True to keep a Networkx graph instance with the curation history + Returns ------- sorting : Sorting diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 11f26ea778..df5bb7446c 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -13,7 +13,7 @@ class MergeUnitsSorting(BaseSorting): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object units_to_merge : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 874552f767..bf03afbb8b 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -58,6 +58,10 @@ def remove_redundant_units( Used when remove_strategy="highest_amplitude" extra_outputs : bool, default: False If True, will return the redundant pairs. + unit_peak_shifts : dict + Dictionary mapping the unit_id to the unit's shift (in number of samples). + A positive shift means the spike train is shifted back in time, while + a negative shift means the spike train is shifted forward. 
Returns ------- diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 33c14dfe5a..0804f637a5 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -13,9 +13,9 @@ class SplitUnitSorting(BaseSorting): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object - parent_unit_id : int + split_unit_id : int Unit id of the unit to split indices_list : list or np.array A list of index arrays selecting the spikes to split in each segment. @@ -28,6 +28,7 @@ class SplitUnitSorting(BaseSorting): Policy used to propagate properties. If "keep" the properties will be passed to the new units (if the units_to_merge have the same value). If "remove" the new units will have an empty value for all the properties of the new unit + Returns ------- sorting : Sorting diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 58110cf7aa..04e41433e1 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -32,7 +32,7 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the - names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. rec_name : str, default: None When the file contains several recordings you need to specify the one you want to extract. (rec_name='rec0000'). 
diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 0adddc2439..eed3188d16 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -25,7 +25,7 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. use_names_as_ids : bool, default: True Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the - names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. Example for wideband signals: names: ["WB01", "WB02", "WB03", "WB04"] diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 0e4f1985c6..aa59de8f60 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -116,6 +116,16 @@ class DriftingTemplates(Templates): * move every templates on-the-fly, this lead to one interpolation per spike * precompute some displacements for all templates and use a discreate interpolation, for instance by step of 1um This is the same strategy used by MEArec. + + Parameters + ---------- + templates_array_moved : np.array + Shape is (num_displacement, num_templates, num_samples, num_channels) + displacements : np.array + Displacement vector + shape : (num_displacement, 2) + **static_kwargs : dict + Keyword arguments for `Templates` """ def __init__(self, templates_array_moved=None, displacements=None, **static_kwargs): @@ -306,6 +316,8 @@ class InjectDriftingTemplatesRecording(BaseRecording): If None, no amplitude scaling is applied. If scalar all spikes have the same factor (certainly useless). If vector, it must have the same size as the spike vector. 
+ mode : str, default: "precompute" + Mode for how to compute templates. Returns ------- diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index b439c57c52..69f1fb6375 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -194,6 +194,8 @@ def generate_displacement_vector( motion_list : list of dict List of dicts containing individual motion vector parameters. len(motion_list) == displacement_vectors.shape[2] + seed : None | int, default: None + Random seed for `make_one_displacement_vector` Returns ------- diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 0c82e496c0..12958649dd 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -42,6 +42,8 @@ def estimate_templates_from_recording( Parameters ---------- + recording : BaseRecording + The recording to get templates from. ms_before : float The time before peaks of templates. ms_after : float The time after peaks of templates. @@ -181,6 +183,8 @@ def scale_template_to_range( The minimum amplitude of the output templates after scaling. max_amplitude : float The maximum amplitude of the output templates after scaling. + amplitude_function : "ptp" | "min" | "max", default: "ptp" + The function to use to compute the amplitude of the templates. Can be "ptp", "min" or "max". Returns ------- @@ -356,10 +360,6 @@ def generate_hybrid_recording( are_templates_scaled : bool, default: True If True, the templates are assumed to be in uV, otherwise in the same unit as the recording. In case the recording has scaling, the templates are "unscaled" before injection.
If not provided, generated (see generate_unit_location_kwargs). diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index 17d2bdf521..6d094adf11 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -71,6 +71,8 @@ def query_templates_from_database(template_df: "pandas.DataFrame", verbose: bool ---------- template_df : pd.DataFrame Dataframe containing the template information, obtained by slicing/querying the output of fetch_templates_info. + verbose : bool, default: False + If True, output is verbose Returns ------- diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 3c65f2075c..8da1ed752a 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -20,7 +20,7 @@ class ComputeCorrelograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object window_ms : float, default: 50.0 The window around the spike to compute the correlation in ms. 
For example, diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index fa919e11e2..542f829f21 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -17,7 +17,7 @@ class ComputeISIHistograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object window_ms : float, default: 50 The window in ms diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 99c60a5043..f1f89403c7 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -25,21 +25,21 @@ class ComputePrincipalComponents(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - n_components: int, default: 5 + n_components : int, default: 5 Number of components fo PCA - mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" + mode : "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" The PCA mode: - "by_channel_local": a local PCA is fitted for each channel (projection by channel) - "by_channel_global": a global PCA is fitted for all channels (projection by channel) - "concatenated": channels are concatenated and a global PCA is fitted - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None The sparsity to apply to waveforms. 
If sorting_analyzer is already sparse, the default sparsity will be used - whiten: bool, default: True + whiten : bool, default: True If True, waveforms are pre-whitened - dtype: dtype, default: "float32" + dtype : dtype, default: "float32" Dtype of the pc scores Examples diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index e82a9e61e4..9e8b5993b9 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -22,13 +22,13 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs: dict + spike_retriver_kwargs : dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 53e55b4d1f..6995fc04da 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -17,13 +17,13 @@ class ComputeSpikeLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs: dict + spike_retriver_kwargs : dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 
eef2a2f32c..31652d8afc 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -50,7 +50,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer The SortingAnalyzer object metric_names : list or None, default: None List of metrics to compute (see si.postprocessing.get_template_metric_names()) @@ -58,13 +58,13 @@ class ComputeTemplateMetrics(AnalyzerExtension): Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 The upsampling factor to upsample the templates - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None If None, template metrics are computed on the extremum channel only. If sparsity is given, template metrics are computed on all sparse channels of each unit. For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. - include_multi_channel_metrics: bool, default: False + include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics - metrics_kwargs: dict + metrics_kwargs : dict Additional arguments to pass to the metric functions. 
Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cb4cc323ad..53df14ff8a 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -13,25 +13,22 @@ class ComputeTemplateSimilarity(AnalyzerExtension): Similarity is defined as 1 - distance(T_1, T_2) for two templates T_1, T_2 - Parameters ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object method : str, default: "cosine" The method to compute the similarity. Can be in ["cosine", "l2", "l1"] + In case of "l1" or "l2", the formula used is: + - similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)). + In case of cosine it is: + - similarity = 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)). max_lag_ms : float, default: 0 If specified, the best distance for all given lag within max_lag_ms is kept, for every template support : "dense" | "union" | "intersection", default: "union" Support that should be considered to compute the distances between the templates, given their sparsities. 
Can be either ["dense", "union", "intersection"] - In case of "l1" or "l2", the formula used is: - similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)) - - In case of cosine this is: - similarity = 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)) - Returns ------- similarity: np.array diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 818f0a8062..5d190d43f1 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -24,16 +24,16 @@ class ComputeUnitLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The method to use for localization - method_kwargs: dict, default: {} + method_kwargs : dict, default: {} Other kwargs depending on the method Returns ------- - unit_locations: np.array + unit_locations : np.array unit location with shape (num_unit, 2) or (num_unit, 3) or (num_unit, 3) (with alpha) """ diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 8d1f9bc9f3..ddb981a944 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -216,6 +216,11 @@ def get_motion_presets(): def get_motion_parameters_preset(preset): """ Get the parameters tree for a given preset for motion correction. + + Parameters + ---------- + preset : str, default: None + The preset name. See available presets using `spikeinterface.preprocessing.get_motion_presets()`. 
""" preset_params = copy.deepcopy(motion_options_preset[preset]) all_default_params = _get_default_motion_params() diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 7465d58737..4e1136da91 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,7 +69,7 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes -def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs): +def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -98,7 +98,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs): return firing_rates -def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): +def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -620,7 +620,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) -def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): +def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -1437,6 +1437,8 @@ def compute_sd_ratio( In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit. (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself). + TODO: Take jitter into account. 
+ Parameters ---------- sorting_analyzer : SortingAnalyzer @@ -1450,9 +1452,8 @@ def compute_sd_ratio( and will make a rough estimation of what that impact is (and remove it). unit_ids : list or None, default: None The list of unit ids to compute this metric. If None, all units are used. - **kwargs: + **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. - TODO: Take jitter into account. Returns ------- diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index c7127226b0..7ed5b29556 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -250,6 +250,8 @@ def run_sorter_by_property( Controls sorter verboseness docker_image : None or str, default: None If str run the sorter inside a container (docker) using the docker package + singularity_image : None or str, default: None + If str run the sorter inside a container (singularity) using the docker package **sorter_params : keyword args Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 80608f8973..17700e7df8 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -96,23 +96,6 @@ If True, the output Sorting is returned as a Sorting delete_container_files : bool, default: True If True, the container temporary files are deleted after the sorting is done - extra_requirements : list, default: None - List of extra requirements to install in the container - installation_mode : "auto" | "pypi" | "github" | "folder" | "dev" | "no-install", default: "auto" - How spikeinterface is installed in the container: - * "auto" : if host installation is a pip release then use "github" with tag - if host installation is DEV_MODE=True then use "dev" - * "pypi" : use pypi with pip install spikeinterface - * 
"github" : use github with `pip install git+https` - * "folder" : mount a folder in container and install from this one. - So the version in the container is a different spikeinterface version from host, useful for - cross checks - * "dev" : same as "folder", but the folder is the spikeinterface.__file__ to ensure same version as host - * "no-install" : do not install spikeinterface in the container because it is already installed - spikeinterface_version : str, default: None - The spikeinterface version to install in the container. If None, the current version is used - spikeinterface_folder_source : Path or None, default: None - In case of installation_mode="folder", the spikeinterface folder source to use to install in the container output_folder : None, default: None Do not use. Deprecated output function to be removed in 0.103. **sorter_params : keyword args @@ -691,7 +674,9 @@ def read_sorter_folder(folder, register_recording=True, sorting_info=True, raise register_recording : bool, default: True Attach recording (when json or pickle) to the sorting sorting_info : bool, default: True - Attach sorting info to the sorting. 
+ Attach sorting info to the sorting + raise_error : bool, default: True + Raise an error if the spike sorting failed """ folder = Path(folder) log_file = folder / "spikeinterface_log.json" diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 7381875557..99881f2f34 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -12,15 +12,15 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object - peaks: numpy.array + peaks : numpy.array The peak vector - method: str + method : str Which method to use ("stupid" | "XXXX") - method_kwargs: dict, default: dict() + method_kwargs : dict, default: dict() Keyword arguments for the chosen method - extra_outputs: bool, default: False + extra_outputs : bool, default: False If True then debug is also return {} diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 9476a0df03..88b31476a9 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -14,20 +14,22 @@ def find_spikes_from_templates( Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object - method: "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble" + method : "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble" Which method to use for template matching - method_kwargs: dict, optional + method_kwargs : dict, optional Keyword arguments for the chosen method - extra_outputs: bool + extra_outputs : bool If True then method_kwargs is also returned - job_kwargs: dict + **job_kwargs : dict Parameters for ChunkRecordingExecutor + verbose : bool, default: False + 
If True, output is verbose Returns ------- - spikes: ndarray + spikes : ndarray Spikes found from templates. method_kwargs: Optionaly returns for debug purpose. diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py index 2fc1a281a9..87dca64496 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py @@ -11,16 +11,16 @@ def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=3 Parameters ---------- - motion: numpy array 2d + motion : numpy array 2d Motion estimate in um. - temporal_bins: numpy.array 1d + temporal_bins : numpy.array 1d temporal bins (bin center) - bin_duration_s: float + bin_duration_s : float bin duration in second - speed_threshold: float (units um/s) + speed_threshold : float (units um/s) Maximum speed treshold between 2 bins allowed. Expressed in um/s - sigma_smooth_s: None or float + sigma_smooth_s : None or float Optional smooting gaussian kernel. Returns diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 0d425c98da..c75f7129aa 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -43,20 +43,19 @@ def estimate_motion( Parameters ---------- - recording: BaseRecording + recording : BaseRecording The recording extractor - peaks: numpy array + peaks : numpy array Peak vector (complex dtype). Needed for decentralized and iterative_template methods. - peak_locations: numpy array + peak_locations : numpy array Complex dtype with "x", "y", "z" fields Needed for decentralized and iterative_template methods. - direction: "x" | "y" | "z", default: "y" + direction : "x" | "y" | "z", default: "y" Dimension on which the motion is estimated. "y" is depth along the probe. 
{method_doc} - **non-rigid section** rigid : bool, default: False Compute rigid (one motion for the entire probe) or non rigid motion @@ -76,15 +75,14 @@ def estimate_motion( See win_shape win_margin_um : None | float, default: None See win_shape - extra_outputs: bool, default: False + extra_outputs : bool, default: False If True then return an extra dict that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - progress_bar: bool, default: False + progress_bar : bool, default: False Display progress bar or not - verbose: bool, default: False + verbose : bool, default: False If True, output is verbose - Returns ------- motion: Motion object diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 11ce11e1aa..57cc4d1371 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -69,23 +69,25 @@ def interpolate_motion_on_traces( Trace snippet (num_samples, num_channels) times : np.array Sample times in seconds for the frames of the traces snippet - channel_location: np.array 2d + channel_locations : np.array 2d Channel location with shape (n, 2) or (n, 3) - motion: Motion + motion : Motion The motion object. - segment_index: int or None + segment_index : int or None The segment index. - channel_inds: None or list + channel_inds : None or list If not None, interpolate only a subset of channels. interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. 
- spatial_interpolation_method: "idw" | "kriging", default: "kriging" + spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing * kriging : kilosort2.5 like - spatial_interpolation_kwargs: - * specific option for the interpolation method + spatial_interpolation_kwargs : dict + specific option for the interpolation method + dtype : np.dtype, default: None + The dtype of the traces. If None, inherits from traces snippet Returns ------- @@ -237,11 +239,11 @@ class InterpolateMotionRecording(BasePreprocessor): Parameters ---------- - recording: Recording + recording : Recording The parent recording. - motion: Motion + motion : Motion The motion object - spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" + spatial_interpolation_method : "kriging" | "idw" | "nearest", default: "kriging" The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. Choice of the method: @@ -249,23 +251,24 @@ class InterpolateMotionRecording(BasePreprocessor): * "kriging" : the same one used in kilosort * "idw" : inverse distance weighted * "nearest" : use neareast channel - sigma_um: float, default: 20.0 + + sigma_um : float, default: 20.0 Used in the "kriging" formula - p: int, default: 1 + p : int, default: 1 Used in the "kriging" formula - num_closest: int, default: 3 + num_closest : int, default: 3 Number of closest channels used by "idw" method for interpolation. 
- border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" + border_mode : "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: * "remove_channels": remove channels on the border, the recording has less channels * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) - interpolation_time_bin_centers_s: np.array or list of np.array, optional + interpolation_time_bin_centers_s : np.array or list of np.array, optional Spatially interpolate each frame according to the displacement estimate at its closest bin center in this array. If not supplied, this is set to the motion estimate's time bin centers. If it's supplied, the motion estimate is interpolated to these bin centers. If you have a multi-segment recording, pass a list of these, one per segment. - interpolation_time_bin_size_s: float, optional + interpolation_time_bin_size_s : float, optional Similar to the previous argument: interpolation_time_bin_centers_s will be constructed by bins spaced by interpolation_time_bin_size_s. This is ignored if interpolation_time_bin_centers_s is supplied. @@ -273,6 +276,8 @@ class InterpolateMotionRecording(BasePreprocessor): Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. If the input recording is already floating and dtype=None, then its dtype is used by default. If the input recording is integer, then float32 is used by default. + **spatial_interpolation_kwargs: dict + Spatial interpolation kwargs for `interpolate_motion_on_traces`. 
Returns ------- diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 203bd2473b..635624cca8 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -17,6 +17,7 @@ class Motion: Motion estimate in um. List is the number of segment. For each semgent : + * shape (temporal bins, spatial bins) * motion.shape[0] = temporal_bins.shape[0] * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index b984853123..4fe90dd7bc 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -59,19 +59,19 @@ def detect_peaks( Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object. - pipeline_nodes: None or list[PipelineNode] + pipeline_nodes : None or list[PipelineNode] Optional additional PipelineNode need to computed just after detection time. This avoid reading the recording multiple times. - gather_mode: str + gather_mode : str How to gather the results: * "memory": results are returned as in-memory numpy arrays * "npy": results are stored to .npy files in `folder` - folder: str or Path + folder : str or Path If gather_mode is "npy", the folder where the files are created. - names: list + names : list List of strings with file stems associated with returns. 
{method_doc} diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index b578eb4478..4dff27e338 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -98,10 +98,14 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object. - peaks: array + peaks : array Peaks array, as returned by detect_peaks() in "compact_numpy" way. + ms_before : float + The number of milliseconds to include before the peak of the spike + ms_after : float + The number of milliseconds to include after the peak of the spike {method_doc} diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index facefac4c5..05552d41a9 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -20,14 +20,14 @@ def make_multi_method_doc(methods, ident=" "): doc = "" - doc += "method: " + ", ".join(f"'{method.name}'" for method in methods) + "\n" + doc += "method : " + ", ".join(f"'{method.name}'" for method in methods) + "\n" doc += ident + " Method to use.\n" for method in methods: doc += "\n" - doc += ident + f"arguments for method='{method.name}'" + doc += ident + ident + f"arguments for method='{method.name}'" for line in method.params_doc.splitlines(): - doc += ident + line + "\n" + doc += ident + ident + line + "\n" return doc From 117e0f69356528fdeca417c6ae770d6c666f89b7 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:00:20 +0100 Subject: [PATCH 2/3] More docstring updates --- src/spikeinterface/core/analyzer_extension_core.py | 3 ++- src/spikeinterface/postprocessing/correlograms.py | 4 ++-- 
src/spikeinterface/postprocessing/spike_amplitudes.py | 4 ++-- src/spikeinterface/postprocessing/template_metrics.py | 6 +++--- src/spikeinterface/postprocessing/unit_locations.py | 6 +++--- .../sortingcomponents/motion/motion_estimation.py | 4 +--- .../sortingcomponents/motion/motion_interpolation.py | 2 +- 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index c9cff4fb94..bc5de63d07 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -677,7 +677,8 @@ class ComputeNoiseLevels(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - **params : dict with additional parameters for the `spikeinterface.get_noise_levels()` function + **kwargs : dict + Additional parameters for the `spikeinterface.get_noise_levels()` function Returns ------- diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 8da1ed752a..7f7946f634 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -20,8 +20,8 @@ class ComputeCorrelograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object + sorting_analyzer_or_sorting : SortingAnalyzer | Sorting + A SortingAnalyzer or Sorting object window_ms : float, default: 50.0 The window around the spike to compute the correlation in ms. For example, if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. 
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 7abd2e625e..2efac0e0d0 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -42,8 +42,8 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): In case channel_from_template=False, this is the peak sign. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use - method_kwargs : dict, default: dict() - Other kwargs depending on the method. + **method_kwargs : dict, default: {} + Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. outputs : "numpy" | "by_unit", default: "numpy" The output format, either concatenated as numpy array or separated on a per unit basis diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 31652d8afc..e54ff87221 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -77,9 +77,9 @@ class ComputeTemplateMetrics(AnalyzerExtension): * spread_threshold: the threshold to compute the spread, default: 0.2 * spread_smooth_um: the smoothing in um to compute the spread, default: 20 * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max 
channel horizontal position are used Returns ------- diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 5d190d43f1..4029fc88c7 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -28,8 +28,8 @@ class ComputeUnitLocations(AnalyzerExtension): A SortingAnalyzer object method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The method to use for localization - method_kwargs : dict, default: {} - Other kwargs depending on the method + **method_kwargs : dict, default: {} + Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. Returns ------- @@ -94,7 +94,7 @@ def _run(self, verbose=False): method_kwargs.pop("method") if method not in _unit_location_methods: - raise ValueError(f"Wrong ethod for unit_locations : it should be in {list(_unit_location_methods.keys())}") + raise ValueError(f"Wrong method for unit_locations : it should be in {list(_unit_location_methods.keys())}") func = _unit_location_methods[method] self.data["unit_locations"] = func(self.sorting_analyzer, **method_kwargs) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index c75f7129aa..62b120e9a0 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -32,8 +32,6 @@ def estimate_motion( **method_kwargs, ): """ - - Estimate motion with several possible methods. Most of methods except dredge_lfp needs peaks and after their localization. 
@@ -56,7 +54,6 @@ def estimate_motion( {method_doc} - rigid : bool, default: False Compute rigid (one motion for the entire probe) or non rigid motion Rigid computation is equivalent to non-rigid with only one window with rectangular shape. @@ -82,6 +79,7 @@ Display progress bar or not verbose : bool, default: False If True, output is verbose + **method_kwargs : Kwargs which are passed to the chosen method. Returns ------- diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 57cc4d1371..4912c26ca0 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -276,7 +276,7 @@ class InterpolateMotionRecording(BasePreprocessor): Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. If the input recording is already floating and dtype=None, then its dtype is used by default. If the input recording is integer, then float32 is used by default. - **spatial_interpolation_kwargs: dict + **spatial_interpolation_kwargs : dict Spatial interpolation kwargs for `interpolate_motion_on_traces`. 
Returns From 1b12d54267b304f8383684c5d2b60c607f4b1a65 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:46:07 +0100 Subject: [PATCH 3/3] Reply to review --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- src/spikeinterface/comparison/multicomparisons.py | 10 +++++----- src/spikeinterface/comparison/paircomparisons.py | 4 ++-- src/spikeinterface/core/generate.py | 12 ++++++------ src/spikeinterface/generation/hybrid_tools.py | 2 +- .../postprocessing/template_similarity.py | 4 ++-- .../sortingcomponents/matching/main.py | 2 +- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index d45956a07e..8929d6983c 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -45,7 +45,7 @@ class GroundTruthStudy: Parameters ---------- - study_folder : srt | Path + study_folder : str | Path Path to folder containing `GroundTruthStudy` """ diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index dccff6118d..f7d9782a07 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -43,9 +43,9 @@ class MultiSortingComparison(BaseMultiComparison, MixinSpikeTrainComparison): - "intersection" : spike trains are the intersection between the spike trains of the best matching two sorters verbose : bool, default: False - if True, output is verbose + If True, output is verbose do_matching : bool, default: True - if True, SOMETHING HAPPENS. 
+ If True, the comparison is done when the `MultiSortingComparison` is initialized Returns ------- @@ -320,15 +320,15 @@ class MultiTemplateComparison(BaseMultiComparison, MixinTemplateComparison): chance_score : float, default: 0.3 Minimum agreement score to for a possible match verbose : bool, default: False - if True, output is verbose + If True, output is verbose do_matching : bool, default: True - if True, IT DOES SOMETHING + If True, the comparison is done when the `MultiTemplateComparison` is initialized support : "dense" | "union" | "intersection", default: "union" The support to compute the similarity matrix. num_shifts : int, default: 0 Number of shifts to use to shift templates to maximize similarity. similarity_method : "cosine" | "l1" | "l2", default: "cosine" - Method for the similarity matrix. + Method for the similarity matrix. Returns ------- diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 9566354918..f5e7cdcc1f 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -566,7 +566,7 @@ def count_false_positive_units(self, redundant_score=None): Parameters ---------- - redundant_score : float, default: None + redundant_score : float | None, default: None The agreement score below which tested units are counted as "false positive"" (and not "redundant"). """ @@ -747,7 +747,7 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): List of units from sorting_analyzer_2 to compare. name1 : str, default: "sess1" Name of first session. - name2 : str, default: "sess1" + name2 : str, default: "sess2" Name of second session. similarity_method : "cosine" | "l1" | "l2", default: "cosine" Method for the similarity matrix. 
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 187103d031..f8ab8a2d3a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -44,7 +44,7 @@ def generate_recording( The number of channels in the recording. sampling_frequency : float, default: 30000. (in Hz) The sampling frequency of the recording, default: 30000. - durations : List[float], default: [5.0, 2.5] + durations : list[float], default: [5.0, 2.5] The duration in seconds of each segment in the recording. The number of segments is determined by the length of this list. set_probe : bool, default: True @@ -1295,8 +1295,8 @@ def generate_recording_by_size( The seed for np.random.default_rng. strategy : "tile_pregenerated"| "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: - * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and cusume only one noise block. - * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) + * "tile_pregenerated": pregenerate a noise chunk of `noise_block_size` samples and repeat it quickly consuming only one noise block. + * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index. No memory preallocation but a bit more computation (random) Returns ------- @@ -2062,9 +2062,9 @@ def generate_ground_truth_recording( Parameters ---------- - durations : list of float, default: [10.] + durations : list[float], default: [10.] Durations in seconds for all segments. - sampling_frequency : float, default: 25000 + sampling_frequency : float, default: 25000.0 Sampling frequency. num_channels : int, default: 4 Number of channels, not used when probe is given. 
@@ -2085,7 +2085,7 @@ def generate_ground_truth_recording( * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. ms_before : float, default: 1.5 Cut out in ms before spike peak. - ms_after : float, default: 3 + ms_after : float, default: 3.0 Cut out in ms after spike peak. upsample_factor : None or int, default: None A upsampling factor used only when templates are not provided. diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 12958649dd..747389a6d7 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -183,7 +183,7 @@ def scale_template_to_range( The minimum amplitude of the output templates after scaling. max_amplitude : float The maximum amplitude of the output templates after scaling. - amplitude_function : "ptp" | "min" | "max", default: "pip" + amplitude_function : "ptp" | "min" | "max", default: "ptp" The function to use to compute the amplitude of the templates. Can be "ptp", "min" or "max". Returns diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 53df14ff8a..27214f32e6 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -17,8 +17,8 @@ class ComputeTemplateSimilarity(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - method : str, default: "cosine" - The method to compute the similarity. Can be in ["cosine", "l2", "l1"] + method : "cosine" | "l1" | "l2", default: "cosine" + The method to compute the similarity. In case of "l1" or "l2", the formula used is: - similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)). 
In case of cosine it is: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 88b31476a9..fa2f7c055e 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -16,7 +16,7 @@ def find_spikes_from_templates( ---------- recording : RecordingExtractor The recording extractor object - method : "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble" + method : "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble", default: "naive" Which method to use for template matching method_kwargs : dict, optional Keyword arguments for the chosen method