diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e907976163..9770856dfa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.10.0 + rev: 23.10.1 hooks: - id: black files: ^src/ diff --git a/doc/development/development.rst b/doc/development/development.rst index 7656da11ab..dc10b0c470 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -50,7 +50,7 @@ If you want to run a specific test in a specific file, you can use the following .. code-block:: bash - pytest pytest src/spikeinterface/core/tests/test_baserecording.py::specific_test_in_this_module + pytest src/spikeinterface/core/tests/test_baserecording.py::specific_test_in_this_module We also mantain pytest markers to run specific tests. For example, if you want to run only the tests for the :code:`spikeinterface.extractors` module, you can use the following command: @@ -77,97 +77,121 @@ The extractor tests require datalad for some of the tests. Here are instructions Installing Datalad ------------------ -First install the datalad-installer package using pip: +In order to get datalad for your OS please see the `datalad instruction `_. +For more information on datalad visit the `datalad handbook `_. +Note, this will also require having git-annex. The instruction links above provide information on also +downloading git-annex for your particular OS. -.. code-block:: shell +Stylistic conventions +--------------------- - pip install datalad-installer +SpikeInterface maintains a consistent coding style across the project. This helps to ensure readability and +maintainability of the code, making it easier for contributors to collaborate. To facilitate code style +for the developer we use the follwing tools and conventions: -The following instructions depend on the operating system you are using: -Linux -^^^^^ -.. code-block:: shell +Install Black and pre-commit +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - datalad-installer --sudo ok git-annex --method datalad/packages +We use the python formatter Black, with defaults set in the :code:`pyproject.toml`. This allows for +easy local formatting of code. -Mac OS -^^^^^^ -.. code-block:: shell +To install Black, you can use pip, the Python package installer. Run the following command in your terminal: - datalad-installer --sudo ok git-annex --method brew +.. code-block:: bash -Windows -^^^^^^^ + pip install black -.. code-block:: shell +This will install Black into your current Python environment. - datalad-installer --sudo ok git-annex --method datalad/git-annex:release +In addition to Black, we use pre-commit to manage a suite of code formatting. +Pre-commit helps to automate the process of running these tools (including Black) before every commit, +ensuring that all code is checked for style. +You can install pre-commit using pip as well: -The following steps are common to all operating systems: +.. code-block:: bash -.. code-block:: shell + pip install pre-commit - pip install datalad -(Optional) Configure Git to use git-annex for large files for efficiency: +Once pre-commit is installed, you can set up the pre-commit hooks for your local repository. +These hooks are scripts that pre-commit will run prior to each commit. To install the pre-commit hooks, +navigate to your local repository in your terminal and run the following command: -.. code-block:: shell +.. code-block:: bash - git config --global filter.annex.process "git-annex filter-process" + pre-commit install -Stylistic conventions ---------------------- +Now, each time you make a commit, pre-commit will automatically run Black and any other configured hooks. +If the hooks make changes or if there are any issues, the commit will be stopped, and you'll be able to review and add the changes. +If you want Black to omit a line from formatting, you can add the following comment to the end of the line: -SpikeInterface maintains a consistent coding style across the project, leveraging the black Python code formatter. -This helps to ensure readability and maintainability of the code, making it easier for contributors to collaborate. +.. code-block:: python -To install black, you can use pip, the Python package installer. Run the following command in your terminal: + # fmt: skip -.. code-block:: bash +To ignore a block of code you must flank the code with two comments: - pip install black +.. code-block:: python -This will install black into your current Python environment. + # fmt: off + code here + # fmt: on -In addition to black, we use pre-commit to manage a suite of code formatting. -Pre-commit helps to automate the process of running these tools before every commit, -ensuring that all code is checked for style. +As described in the `black documentation `_. -You can install pre-commit using pip as well: + +Docstring Conventions +^^^^^^^^^^^^^^^^^^^^^ + +For docstrings, SpikeInterface generally follows the `numpy docstring standard `_. +This includes providing a one line summary of a function, and the standard NumPy sections including :code:`Parameters`, :code:`Returns`, etc. The format used +for providing parameters, however is a little different. The project prefers the format: .. code-block:: bash - pip install pre-commit + parameter_name: type, default: default_value -Once pre-commit is installed, you can set up the pre-commit hooks for your local repository. -These hooks are scripts that pre-commit will run prior to each commit. To install the pre-commit hooks, -navigate to your local repository in your terminal and run the following command: +This allows users to quickly understand the type of data that should be input into a function as well as whether a default is supplied. A full example would be: -.. code-block:: bash +.. code-block:: python - pre-commit install + def a_function(param_a, param_b=5, param_c="mean"): + """ + A function for analyzing data -Now, each time you make a commit, pre-commit will automatically run black and any other configured hooks. -If the hooks make changes or if there are any issues, the commit will be stopped, and you'll be able to review and add the changes. + Parameters + ---------- + param_a: dict + A dictionary containing the data + param_b: int, default: 5 + A scaling factor to be applied to the data + param_c: "mean" | "median", default: "mean" + What to calculate on the data -If you want black to omit a line from formatting, you can add the following comment to the end of the line: + Returns + ------- + great_data: dict + A dictionary of the processed data + """ -.. code-block:: python - # fmt: off +Note that in this example we demonstrate two other docstring conventions followed by SpikeInterface. First, that all string arguments should be presented +with double quotes. This is the same stylistic convention followed by Black and enforced by the pre-commit for the repo. Second, when a parameter is a +string with a limited number of values (e.g. :code:`mean` and :code:`median`), rather than give the type a value of :code:`str`, please list the possible strings +so that the user knows what the options are. -As described in the `black documentation `_, -The following are some styling conventions that we follow in SpikeInterface: +Miscelleaneous Stylistic Conventions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. -#. Use index as singular and indices for plural following Numpy. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) +#. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. -#. Use the `numpy docstring standard `_ in all the docstrings. + How to build the documentation ------------------------------ @@ -199,14 +223,14 @@ Implement a new extractor ------------------------- SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the -supported formats list (***ref***). Most of the extractord rely on the `NEO `_ +supported formats list (***ref***). Most of the extractors rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:`neo.rawio.BaseRawIO` class (see `example `_). +Therefore, to implement a new extractor to handle the unsupported format, we recommend making a new :code:`neo.rawio.BaseRawIO` class (see `example `_). Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` -(for py:class:`~spikeinterface.core.BaseSorting` objects) or with a few lines of +(for :py:class:`~spikeinterface.core.BaseSorting` objects) or with a few lines of code (e.g., see reader for `SpikeGLX `_ or `Neuralynx `_). @@ -345,9 +369,9 @@ Moreover, you have to add a launcher function like `run_XXXX()`. When you are done you need to write a test in **tests/test_myspikesorter.py**. In order to be tested, you can -install the required packages by changing the **.travis.yml**. Note that MATLAB based tests cannot be run at the moment, +install the required packages by changing the **pyproject.toml**. Note that MATLAB based tests cannot be run at the moment, but we recommend testing the implementation locally. -After this you need to add a block in doc/sorters_info.rst +After this you need to add a block in **doc/sorters_info.rst** Finally, make a pull request to the spikesorters repo, so we can review the code and merge it to the spikesorters! diff --git a/pyproject.toml b/pyproject.toml index 51efe1f585..fb7f08a038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,13 @@ widgets = [ "sortingview>=0.11.15", ] +qualitymetrics = [ + "scikit-learn", + "scipy", + "pandas", + "numba", +] + test_core = [ "pytest", "zarr", @@ -178,6 +185,7 @@ markers = [ "widgets", "sortingcomponents", "streaming_extractors: extractors that require streaming such as ross and fsspec", + "ros3_test" ] filterwarnings =[ 'ignore:.*distutils Version classes are deprecated.*:DeprecationWarning', diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 5af20d79b5..7a231f3cb4 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -223,7 +223,7 @@ class BasePairComparison(BaseComparison): It handles the matching procedurs. Agreement scores must be computed in inherited classes by overriding the - '_do_agreement(self)' function + "_do_agreement(self)" function """ def __init__(self, object1, object2, name1, name2, match_score=0.5, chance_score=0.1, verbose=False): diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 20ee7910b4..7a1fb87175 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -3,7 +3,6 @@ """ import numpy as np -from joblib import Parallel, delayed def count_matching_events(times1, times2, delta=10): @@ -109,48 +108,169 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev return matching_event_counts -def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1): - """ - Make the match_event_count matrix. - Basically it counts the matching events for all given pairs of spike trains from - sorting1 and sorting2. +def get_optimized_compute_matching_matrix(): + """ + This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function + which uses numba. I tested using the numba dispatcher programatically to avoids this + but the performance improvements were lost. Think you can do better? Don't forget to measure performance against + the current implementation! + TODO: unify numba decorator across all modules + """ + + if hasattr(get_optimized_compute_matching_matrix, "_cached_function"): + return get_optimized_compute_matching_matrix._cached_function + + import numba + + @numba.jit(nopython=True, nogil=True) + def compute_matching_matrix( + frames_spike_train1, + frames_spike_train2, + unit_indices1, + unit_indices2, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ): + """ + Compute a matrix representing the matches between two spike trains. + + Given two spike trains, this function finds matching spikes based on a temporal proximity criterion + defined by `delta_frames`. The resulting matrix indicates the number of matches between units + in `frames_spike_train1` and `frames_spike_train2`. + + Parameters + ---------- + frames_spike_train1 : ndarray + Array of frames for the first spike train. Should be ordered in ascending order. + frames_spike_train2 : ndarray + Array of frames for the second spike train. Should be ordered in ascending order. + unit_indices1 : ndarray + Array indicating the unit indices corresponding to each spike in `frames_spike_train1`. + unit_indices2 : ndarray + Array indicating the unit indices corresponding to each spike in `frames_spike_train2`. + num_units_sorting1 : int + Total number of units in the first spike train. + num_units_sorting2 : int + Total number of units in the second spike train. + delta_frames : int + Maximum difference in frames between two spikes to consider them as a match. + + Returns + ------- + matching_matrix : ndarray + A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents + the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`. + + Notes + ----- + This algorithm identifies matching spikes between two ordered spike trains. + By iterating through each spike in the first train, it compares them against spikes in the second train, + determining matches based on the two spikes frames being within `delta_frames` of each other. + + To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`, + which signifies the minimal index in the second spike train that might match the upcoming spike + in the first train. This means that the start of the search moves forward in the second train as the + matches between the two trains are found decreasing the number of comparisons needed. + + An important condition here is thatthe same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match` + + For more details on the rationale behind this approach, refer to the documentation of this module and/or + the metrics section in SpikeForest documentation. + """ + + matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + + # Used to avoid the same spike matching twice + previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) + previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) + + lower_search_limit_in_second_train = 0 + + for index1 in range(len(frames_spike_train1)): + # Keeps track of which frame in the second spike train should be used as a search start for matches + index2 = lower_search_limit_in_second_train + frame1 = frames_spike_train1[index1] + + # Determine next_frame1 if current frame is not the last frame + not_in_the_last_loop = index1 < len(frames_spike_train1) - 1 + if not_in_the_last_loop: + next_frame1 = frames_spike_train1[index1 + 1] + + while index2 < len(frames_spike_train2): + frame2 = frames_spike_train2[index2] + not_a_match = abs(frame1 - frame2) > delta_frames + if not_a_match: + # Go to the next frame in the first train + break + + # Map the match to a matrix + row, column = unit_indices1[index1], unit_indices2[index2] + + # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint + if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: + previous_frame1_match[row, column] = frame1 + previous_frame2_match[row, column] = frame2 + + matching_matrix[row, column] += 1 + + index2 += 1 + + # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match + not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames + if not_a_match_with_next: + lower_search_limit_in_second_train = index2 + + return matching_matrix + + # Cache the compiled function + get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix + + return compute_matching_matrix + + +def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): + num_units_sorting1 = sorting1.get_num_units() + num_units_sorting2 = sorting2.get_num_units() + matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + + spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) + spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) + + num_segments_sorting1 = sorting1.get_num_segments() + num_segments_sorting2 = sorting2.get_num_segments() + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" + + # Segments should be matched one by one + for segment_index in range(num_segments_sorting1): + spike_vector1 = spike_vector1_segments[segment_index] + spike_vector2 = spike_vector2_segments[segment_index] + + sample_frames1_sorted = spike_vector1["sample_index"] + sample_frames2_sorted = spike_vector2["sample_index"] + + unit_indices1_sorted = spike_vector1["unit_index"] + unit_indices2_sorted = spike_vector2["unit_index"] - Parameters - ---------- - sorting1: SortingExtractor - The first sorting extractor - sorting2: SortingExtractor - The second sorting extractor - delta_frames: int - Number of frames to consider spikes coincident - n_jobs: int - Number of jobs to run in parallel + matching_matrix += get_optimized_compute_matching_matrix()( + sample_frames1_sorted, + sample_frames2_sorted, + unit_indices1_sorted, + unit_indices2_sorted, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ) - Returns - ------- - match_event_count: array (int64) - Matrix of match count spike - """ + # Build a data frame from the matching matrix import pandas as pd - unit1_ids = np.array(sorting1.get_unit_ids()) - unit2_ids = np.array(sorting2.get_unit_ids()) - - match_event_counts = np.zeros((len(unit1_ids), len(unit2_ids)), dtype="int64") - - # preload all spiketrains 2 into a list - for segment_index in range(sorting1.get_num_segments()): - s2_spiketrains = [sorting2.get_unit_spike_train(u2, segment_index=segment_index) for u2 in unit2_ids] - - match_event_count_segment = Parallel(n_jobs=n_jobs)( - delayed(count_match_spikes)( - sorting1.get_unit_spike_train(u1, segment_index=segment_index), s2_spiketrains, delta_frames - ) - for i1, u1 in enumerate(unit1_ids) - ) - match_event_counts += np.array(match_event_count_segment) - - match_event_counts_df = pd.DataFrame(np.array(match_event_counts), index=unit1_ids, columns=unit2_ids) + unit_ids_of_sorting1 = sorting1.get_unit_ids() + unit_ids_of_sorting2 = sorting2.get_unit_ids() + match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2) return match_event_counts_df @@ -570,7 +690,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun def do_count_score(event_counts1, event_counts2, match_12, match_event_count): """ For each ground truth units count how many: - 'tp', 'fn', 'cl', 'fp', 'num_gt', 'num_tested', 'tested_id' + "tp", "fn", "cl", "fp", "num_gt", "num_tested", "tested_id" Parameters ---------- @@ -634,8 +754,8 @@ def compute_performance(count_score): Note : * we don't have TN because it do not make sens here. - * 'accuracy' = 'tp_rate' because TN=0 - * 'recall' = 'sensitivity' + * "accuracy" = "tp_rate" because TN=0 + * "recall" = "sensitivity" """ import pandas as pd @@ -674,7 +794,7 @@ def make_matching_events(times1, times2, delta): Returns ------- - matching_event: numpy array dtype = ['index1', 'index2', 'delta'] + matching_event: numpy array dtype = ["index1", "index2", "delta"] 1d of collision """ times_concat = np.concatenate((times1, times2)) @@ -731,8 +851,8 @@ def make_collision_events(sorting, delta): ------- collision_events: numpy array dtype = [('index1', 'int64'), ('unit_id1', 'int64'), - ('index2', 'int64'), ('unit_id2', 'int64'), - ('delta', 'int64')] + ('index2', 'int64'), ('unit_id2', 'int64'), + ('delta', 'int64')] 1d of all collision """ unit_ids = np.array(sorting.get_unit_ids()) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index f44e14c4c4..bc7d76ea5a 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -25,22 +25,22 @@ class MultiSortingComparison(BaseMultiComparison, MixinSpikeTrainComparison): ---------- sorting_list: list List of sorting extractor objects to be compared - name_list: list - List of spike sorter names. If not given, sorters are named as 'sorter0', 'sorter1', 'sorter2', etc. - delta_time: float - Number of ms to consider coincident spikes (default 0.4 ms) - match_score: float - Minimum agreement score to match units (default 0.5) - chance_score: float - Minimum agreement score to for a possible match (default 0.1) - n_jobs: int + name_list: list, default: None + List of spike sorter names. If not given, sorters are named as "sorter0", "sorter1", "sorter2", etc. + delta_time: float, default: 0.4 + Number of ms to consider coincident spikes + match_score: float, default: 0.5 + Minimum agreement score to match units + chance_score: float, default: 0.1 + Minimum agreement score to for a possible match + n_jobs: int, default: -1 Number of cores to use in parallel. Uses all available if -1 - spiketrain_mode: str + spiketrain_mode: "union" | "intersection", default: "union" Mode to extract agreement spike trains: - - 'union': spike trains are the union between the spike trains of the best matching two sorters - - 'intersection': spike trains are the intersection between the spike trains of the + - "union": spike trains are the union between the spike trains of the best matching two sorters + - "intersection": spike trains are the intersection between the spike trains of the best matching two sorters - verbose: bool + verbose: bool, default: False if True, output is verbose Returns @@ -156,15 +156,15 @@ def _do_agreement_matrix(self, minimum_agreement=1): def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_count_only=False): """ - Returns AgreementSortingExtractor with units with a 'minimum_matching' agreement. + Returns AgreementSortingExtractor with units with a "minimum_matching" agreement. Parameters ---------- minimum_agreement_count: int Minimum number of matches among sorters to include a unit. minimum_agreement_count_only: bool - If True, only units with agreement == 'minimum_matching' are included. - If False, units with an agreement >= 'minimum_matching' are included + If True, only units with agreement == "minimum_matching" are included. + If False, units with an agreement >= "minimum_matching" are included Returns ------- @@ -309,13 +309,13 @@ class MultiTemplateComparison(BaseMultiComparison, MixinTemplateComparison): ---------- waveform_list: list List of waveform extractor objects to be compared - name_list: list - List of session names. If not given, sorters are named as 'sess0', 'sess1', 'sess2', etc. - match_score: float - Minimum agreement score to match units (default 0.5) - chance_score: float - Minimum agreement score to for a possible match (default 0.1) - verbose: bool + name_list: list, default: None + List of session names. If not given, sorters are named as "sess0", "sess1", "sess2", etc. + match_score: float, default: 0.8 + Minimum agreement score to match units + chance_score: float, default: 0.3 + Minimum agreement score to for a possible match + verbose: bool, default: False if True, output is verbose Returns diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 75976ed44f..e2dc30493d 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -111,19 +111,19 @@ class SymmetricSortingComparison(BasePairSorterComparison): The first sorting for the comparison sorting2: SortingExtractor The second sorting for the comparison - sorting1_name: str + sorting1_name: str, default: None The name of sorter 1 - sorting2_name: : str + sorting2_name: : str, default: None The name of sorter 2 - delta_time: float - Number of ms to consider coincident spikes (default 0.4 ms) - match_score: float - Minimum agreement score to match units (default 0.5) - chance_score: float - Minimum agreement score to for a possible match (default 0.1) - n_jobs: int + delta_time: float, default: 0.4 + Number of ms to consider coincident spikes + match_score: float, default: 0.5 + Minimum agreement score to match units + chance_score: float, default: 0.1 + Minimum agreement score to for a possible match + n_jobs: int, default: -1 Number of cores to use in parallel. Uses all available if -1 - verbose: bool + verbose: bool, default: False If True, output is verbose Returns @@ -139,7 +139,6 @@ def __init__( sorting1_name=None, sorting2_name=None, delta_time=0.4, - sampling_frequency=None, match_score=0.5, chance_score=0.1, n_jobs=-1, @@ -214,34 +213,35 @@ class GroundTruthComparison(BasePairSorterComparison): The first sorting for the comparison tested_sorting: SortingExtractor The second sorting for the comparison - gt_name: str + gt_name: str, default: None The name of sorter 1 - tested_name: : str + tested_name: : str, default: None The name of sorter 2 - delta_time: float - Number of ms to consider coincident spikes (default 0.4 ms) match_score: float - Minimum agreement score to match units (default 0.5) - chance_score: float - Minimum agreement score to for a possible match (default 0.1) - redundant_score: float - Agreement score above which units are redundant (default 0.2) - overmerged_score: float - Agreement score above which units can be overmerged (default 0.2) - well_detected_score: float - Agreement score above which units are well detected (default 0.8) - exhaustive_gt: bool (default True) + delta_time: float, default: 0.4 + Number of ms to consider coincident spikes + match_score: float, default: 0.5 + Minimum agreement score to match units + chance_score: float, default: 0.1 + Minimum agreement score to for a possible match + redundant_score: float, default: 0.2 + Agreement score above which units are redundant + overmerged_score: float, default: 0.2 + Agreement score above which units can be overmerged + well_detected_score: float, default: 0.8 + Agreement score above which units are well detected + exhaustive_gt: bool, default: False Tell if the ground true is "exhaustive" or not. In other world if the GT have all possible units. It allows more performance measurement. For instance, MEArec simulated dataset have exhaustive_gt=True - match_mode: 'hungarian', or 'best' - What is match used for counting : 'hungarian' or 'best match'. - n_jobs: int + match_mode: "hungarian" | "best", default: "hungarian" + The method to match units + n_jobs: int, default: -1 Number of cores to use in parallel. Uses all available if -1 - compute_labels: bool - If True, labels are computed at instantiation (default False) - compute_misclassifications: bool - If True, misclassifications are computed at instantiation (default False) - verbose: bool + compute_labels: bool, default: False + If True, labels are computed at instantiation + compute_misclassifications: bool, default: False + If True, misclassifications are computed at instantiation + verbose: bool, default: False If True, output is verbose Returns @@ -379,21 +379,21 @@ def _do_score_labels(self): def get_performance(self, method="by_unit", output="pandas"): """ Get performance rate with several method: - * 'raw_count' : just render the raw count table - * 'by_unit' : render perf as rate unit by unit of the GT - * 'pooled_with_average' : compute rate unit by unit and average + * "raw_count" : just render the raw count table + * "by_unit" : render perf as rate unit by unit of the GT + * "pooled_with_average" : compute rate unit by unit and average Parameters ---------- - method: str - 'by_unit', or 'pooled_with_average' - output: str - 'pandas' or 'dict' + method: "by_unit" | "pooled_with_average", default: "by_unit" + The method to compute performance + output: "pandas" | "dict", default: "pandas" + The output format Returns ------- perf: pandas dataframe/series (or dict) - dataframe/series (based on 'output') with performance entries + dataframe/series (based on "output") with performance entries """ import pandas as pd @@ -471,7 +471,7 @@ def get_well_detected_units(self, well_detected_score=None): Parameters ---------- - well_detected_score: float (default 0.8) + well_detected_score: float, default: None The agreement score above which tested units are counted as "well detected". """ @@ -507,7 +507,7 @@ def get_false_positive_units(self, redundant_score=None): Parameters ---------- - redundant_score: float (default 0.2) + redundant_score: float, default: None The agreement score below which tested units are counted as "false positive"" (and not "redundant"). """ @@ -547,7 +547,7 @@ def get_redundant_units(self, redundant_score=None): Parameters ---------- - redundant_score=None: float (default 0.2) + redundant_score=None: float, default: None The agreement score above which tested units are counted as "redundant" (and not "false positive" ). """ @@ -582,8 +582,8 @@ def get_overmerged_units(self, overmerged_score=None): Parameters ---------- - overmerged_score: float (default 0.4) - Tested units with 2 or more agreement scores above 'overmerged_score' + overmerged_score: float, default: None + Tested units with 2 or more agreement scores above "overmerged_score" are counted as "overmerged". """ assert self.exhaustive_gt, "overmerged_units list is valid only if exhaustive_gt=True" @@ -693,16 +693,16 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): The first waveform extractor to get templates to compare we2 : WaveformExtractor The second waveform extractor to get templates to compare - unit_ids1 : list, optional - List of units from we1 to compare, by default None - unit_ids2 : list, optional - List of units from we2 to compare, by default None - similarity_method : str, optional - Method for the similaroty matrix, by default "cosine_similarity" - sparsity_dict : dict, optional - Dictionary for sparsity, by default None - verbose : bool, optional - If True, output is verbose, by default False + unit_ids1 : list, default: None + List of units from we1 to compare + unit_ids2 : list, default: None + List of units from we2 to compare + similarity_method : str, default: "cosine_similarity" + Method for the similaroty matrix + sparsity_dict : dict, default: None + Dictionary for sparsity + verbose : bool, default: False + If True, output is verbose Returns ------- diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 5d5c56d15c..c6494b04d1 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -15,6 +15,7 @@ do_count_score, compute_performance, ) +from spikeinterface.core.generate import generate_sorting def make_sorting(times1, labels1, times2, labels2): @@ -27,25 +28,113 @@ def make_sorting(times1, labels1, times2, labels2): def test_make_match_count_matrix(): delta_frames = 10 - # simple match sorting1, sorting2 = make_sorting( [100, 200, 300, 400], [0, 0, 1, 0], - [ - 101, - 201, - 301, - ], + [101, 201, 301], [0, 0, 5], ) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - # ~ print(match_event_count) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) assert match_event_count.shape[0] == len(sorting1.get_unit_ids()) assert match_event_count.shape[1] == len(sorting2.get_unit_ids()) +def test_make_match_count_matrix_sorting_with_itself_simple(): + delta_frames = 10 + + # simple sorting with itself + sorting1, sorting2 = make_sorting( + [100, 200, 300, 400], + [0, 0, 1, 0], + [100, 200, 300, 400], + [0, 0, 1, 0], + ) + + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + + expected_result = [[3, 0], [0, 1]] + assert_array_equal(match_event_count.to_numpy(), expected_result) + + +def test_make_match_count_matrix_sorting_with_itself_longer(): + seed = 2 + sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[5, 5], seed=seed) + + delta_frame_milliseconds = 0.1 # Short so that we only matches between a unit and itself + delta_frames_seconds = delta_frame_milliseconds / 1000 + delta_frames = delta_frames_seconds * sorting.get_sampling_frequency() + match_event_count = make_match_count_matrix(sorting, sorting, delta_frames) + + match_event_count_as_array = match_event_count.to_numpy() + matches_with_itself = np.diag(match_event_count_as_array) + + # The number of matches with itself should be equal to the number of spikes in each unit + spikes_per_unit_dict = sorting.count_num_spikes_per_unit() + expected_result = np.array([spikes_per_unit_dict[u] for u in spikes_per_unit_dict.keys()]) + assert_array_equal(matches_with_itself, expected_result) + + +def test_make_match_count_matrix_with_mismatched_sortings(): + delta_frames = 10 + + sorting1, sorting2 = make_sorting( + [100, 200, 300, 400], [0, 0, 1, 0], [500, 600, 700, 800], [0, 0, 1, 0] # Completely different spike times + ) + + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + + expected_result = [[0, 0], [0, 0]] # No matches between sorting1 and sorting2 + assert_array_equal(match_event_count.to_numpy(), expected_result) + + +def test_make_match_count_matrix_no_double_matching(): + # Jeremy Magland condition: no double matching + frames_spike_train1 = [100, 105, 120, 1000] + unit_indices1 = [0, 1, 0, 0] + frames_spike_train2 = [101, 150, 1000] + unit_indices2 = [0, 1, 0] + delta_frames = 100 + + # Here the key is that the first frame in the first sorting (120) should not match anything in the second sorting + # Because the matching candidates in the second sorting are already matched to the first two frames + # in the first sorting + + # In detail: + # The first frame in sorting 1 (100) from unit 0 should match: + # * The first frame in sorting 2 (101) from unit 0 + # * The second frame in sorting 2 (150) from unit 1 + # The second frame in sorting 1 (105) from unit 1 should match: + # * The first frame in sorting 2 (101) from unit 0 + # * The second frame in sorting 2 (150) from unit 1 + # The third frame in sorting 1 (120) from unit 0 should not match anything + # The final frame in sorting 1 (1000) from unit 0 should only match the final frame in sorting 2 (1000) from unit 0 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames) + + expected_result = np.array([[2, 1], [1, 1]]) # Only one match is expected despite potential repeats + assert_array_equal(result.to_numpy(), expected_result) + + +def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): + # Challenging condition, this was failing with the previous approach that used np.where and np.diff + frames_spike_train1 = [100, 105, 110] # Will fail with [100, 105, 110, 120] + frames_spike_train2 = [100, 105, 110] + unit_indices1 = [0, 0, 0] # Will fail with [0, 0, 0, 0] + unit_indices2 = [0, 0, 0] + delta_frames = 20 # long enough, so all frames in both sortings are within each other reach + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames) + + expected_result = np.array([[3]]) + assert_array_equal(result.to_numpy(), expected_result) + + def test_make_agreement_scores(): delta_frames = 10 diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index ad31b97d8e..b51bace55f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1,4 +1,5 @@ from pathlib import Path +import re from typing import Any, Iterable, List, Optional, Sequence, Union import importlib import warnings @@ -44,7 +45,7 @@ def __init__(self, main_ids: Sequence) -> None: # store init kwargs for nested serialisation self._kwargs = {} - # 'main_ids' will either be channel_ids or units_ids + # "main_ids" will either be channel_ids or units_ids # They are used for properties self._main_ids = np.array(main_ids) if len(self._main_ids) > 0: @@ -90,7 +91,7 @@ def ids_to_indices(self, ids: Iterable, prefer_slice: bool = False) -> Union[np. * data * properties - 'prefer_slice' is an efficient option that tries to make a slice object + "prefer_slice" is an efficient option that tries to make a slice object when indices are consecutive. """ @@ -171,11 +172,11 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi The property name values : np.array Array of values for the property - ids : list/np.array, optional - List of subset of ids to set the values, by default None - missing_value : object, optional - In case the property is set on a subset of values ('ids' not None), - it specifies the how the missing values should be filled, by default None. + ids : list/np.array, default: None + List of subset of ids to set the values, default: None + missing_value : object, default: None + In case the property is set on a subset of values ("ids" not None), + it specifies the how the missing values should be filled. The missing_value has to be specified for types int and unsigned int. """ @@ -269,8 +270,8 @@ def copy_metadata( If True, only the main annotations/properties are copied. ids: list List of ids to copy the metadata to. If None, all ids are copied. - skip_properties: list - List of properties to skip. Default is None. + skip_properties: list, default: None + List of properties to skip """ if ids is None: @@ -320,18 +321,18 @@ def to_dict( Parameters ---------- - include_annotations: bool - If True, all annotations are added to the dict, by default False - include_properties: bool - If True, all properties are added to the dict, by default False - relative_to: str, Path, or None - If not None, files and folders are serialized relative to this path, by default None + include_annotations: bool, default: False + If True, all annotations are added to the dict + include_properties: bool, default: False + If True, all properties are added to the dict + relative_to: str, Path, or None, default: None + If not None, files and folders are serialized relative to this path Used in waveform extractor to maintain relative paths to binary files even if the containing folder / diretory is moved folder_metadata: str, Path, or None Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties. - recursive: bool - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. + recursive: bool, default: False + If True, all dicitionaries in the kwargs are expanded with `to_dict` as well Returns ------- @@ -342,7 +343,7 @@ def to_dict( kwargs = self._kwargs if relative_to and not recursive: - raise ValueError("`relative_to` is only posible when `recursive=True`") + raise ValueError("`relative_to` is only possible when `recursive=True`") if recursive: to_dict_kwargs = dict( @@ -571,7 +572,12 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No else: raise ValueError("Dump: file must .json or .pkl") - def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None: + def dump_to_json( + self, + file_path: Union[str, Path, None] = None, + relative_to: Union[str, Path, bool, None] = None, + folder_metadata: Union[str, Path, None] = None, + ) -> None: """ Dump recording extractor to json file. The extractor can be re-loaded with load_extractor(json_file) @@ -584,7 +590,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. folder_metadata: str, Path, or None - Folder with files containing additional information (e.g. probe in BaseRecording) and properties. + Folder with files containing additional information (e.g. probe in BaseRecording) and properties """ assert self.check_serializablility("json"), "The extractor is not json serializable" @@ -610,8 +616,9 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non def dump_to_pickle( self, file_path: Union[str, Path, None] = None, + relative_to: Union[str, Path, bool, None] = None, include_properties: bool = True, - folder_metadata=None, + folder_metadata: Union[str, Path, None] = None, ): """ Dump recording extractor to a pickle file. @@ -621,6 +628,9 @@ def dump_to_pickle( ---------- file_path: str Path of the pickle file + relative_to: str, Path, True or None + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. include_properties: bool If True, all properties are dumped folder_metadata: str, Path, or None @@ -628,12 +638,21 @@ def dump_to_pickle( """ assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" + # Writing paths as relative_to requires recursively expanding the dict + if relative_to: + relative_to = Path(file_path).parent if relative_to is True else Path(relative_to) + relative_to = relative_to.resolve().absolute() + # if relative_to is used, the dictionaru needs recursive expansion + recursive = True + else: + recursive = False + dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, folder_metadata=folder_metadata, - relative_to=None, - recursive=False, + relative_to=relative_to, + recursive=recursive, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) @@ -741,7 +760,7 @@ def save(self, **kwargs) -> "BaseExtractor": folder (use set_global_tmp_folder() to change this folder) where the object is saved. If folder and name are not given, the object is saved in the global temporary folder with a random string - * dump_ext: 'json' or 'pkl', default 'json' (if format is "folder") + * dump_ext: "json" or "pkl", default "json" (if format is "folder") * verbose: if True output is verbose * **save_kwargs: additional kwargs format-dependent and job kwargs for recording {} @@ -781,7 +800,7 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): This replaces the use of the old CacheRecordingExtractor and CacheSortingExtractor. - There are 2 option for the 'folder' argument: + There are 2 option for the "folder" argument: * explicit folder: `extractor.save(folder="/path-for-saving/")` * explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")` * generated: `extractor.save()` @@ -796,10 +815,10 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): ---------- name: None str or Path Name of the subfolder in get_global_tmp_folder() - If 'name' is given, 'folder' must be None. + If "name" is given, "folder" must be None. folder: None str or Path Name of the folder. - If 'folder' is given, 'name' must be None. + If "folder" is given, "name" must be None. Returns ------- @@ -866,21 +885,22 @@ def save_to_zarr( Parameters ---------- - name: str or None + name: str or None, default: None Name of the subfolder in get_global_tmp_folder() - If 'name' is given, 'folder' must be None. - folder: str, Path, or None - The folder used to save the zarr output. If the folder does not have a '.zarr' suffix, + If "name" is given, "folder" must be None. + folder: str, Path, or None, default: None + The folder used to save the zarr output. If the folder does not have a ".zarr" suffix, it will be automatically appended. - storage_options: dict or None + storage_options: dict or None, default: None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. For cloud storage locations, this should not be None (in case of default values, use an empty dict) - channel_chunk_size: int or None - Channels per chunk. Default None (chunking in time only) - verbose: bool - If True (default), the output is verbose. - zarr_path: str, Path, or None - (Deprecated) Name of the zarr folder (.zarr). + channel_chunk_size: int or None, default: None + Channels per chunk + verbose: bool, default: True + If True, the output is verbose + zarr_path: str, Path, or None, default: None + (Deprecated) Name of the zarr folder (.zarr) + **save_kwargs: Keyword arguments for saving. Returns ------- diff --git a/src/spikeinterface/core/baseevent.py b/src/spikeinterface/core/baseevent.py index 87651977e5..895a1f501a 100644 --- a/src/spikeinterface/core/baseevent.py +++ b/src/spikeinterface/core/baseevent.py @@ -80,14 +80,14 @@ def get_events( Parameters ---------- - channel_id : int or str, optional - The event channel id, by default None - segment_index : int, optional - The segment index, required for multi-segment objects, by default None - start_time : float, optional - The start time in seconds, by default None - end_time : float, optional - The end time in seconds, by default None + channel_id : int or str, default: None + The event channel id + segment_index : int or None, default: None + The segment index, required for multi-segment objects + start_time : float, default: None + The start time in seconds + end_time : float, default: None + The end time in seconds Returns ------- @@ -110,14 +110,14 @@ def get_event_times( Parameters ---------- - channel_id : int or str, optional - The event channel id, by default None - segment_index : int, optional - The segment index, required for multi-segment objects, by default None - start_time : float, optional - The start time in seconds, by default None - end_time : float, optional - The end time in seconds, by default None + channel_id : int or str, default: None + The event channel id + segment_index : int or None, default: None + The segment index, required for multi-segment objects + start_time : float, default: None + The start time in seconds + end_time : float, default: None + The end time in seconds Returns ------- diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 2977211c25..6dfe038558 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -136,9 +136,9 @@ def get_num_samples(self, segment_index=None) -> int: Parameters ---------- - segment_index : int, optional + segment_index : int or None, default: None The segment index to retrieve the number of samples for. - For multi-segment objects, it is required, by default None + For multi-segment objects, it is required, default: None With single segment recording returns the number of samples in the segment Returns @@ -171,9 +171,9 @@ def get_duration(self, segment_index=None) -> float: Parameters ---------- - segment_index : int, optional + segment_index : int or None, default: None The sample index to retrieve the duration for. - For multi-segment objects, it is required, by default None + For multi-segment objects, it is required, default: None With single segment recording returns the duration of the single segment Returns @@ -204,9 +204,9 @@ def get_memory_size(self, segment_index=None) -> int: Parameters ---------- - segment_index : int, optional + segment_index : int or None, default: None The index of the segment for which the memory size should be calculated. - For multi-segment objects, it is required, by default None + For multi-segment objects, it is required, default: None With single segment recording returns the memory size of the single segment Returns @@ -249,22 +249,22 @@ def get_traces( Parameters ---------- - segment_index : Union[int, None], optional - The segment index to get traces from. If recording is multi-segment, it is required, by default None - start_frame : Union[int, None], optional - The start frame. If None, 0 is used, by default None - end_frame : Union[int, None], optional - The end frame. If None, the number of samples in the segment is used, by default None - channel_ids : Union[Iterable, None], optional - The channel ids. If None, all channels are used, by default None - order : Union[str, None], optional - The order of the traces ("C" | "F"). If None, traces are returned as they are, by default None - return_scaled : bool, optional + segment_index : Union[int, None], default: None + The segment index to get traces from. If recording is multi-segment, it is required, default: None + start_frame : Union[int, None], default: None + The start frame. If None, 0 is used, default: None + end_frame : Union[int, None], default: None + The end frame. If None, the number of samples in the segment is used, default: None + channel_ids : Union[Iterable, None], default: None + The channel ids. If None, all channels are used, default: None + order : Union[str, None], default: None + The order of the traces ("C" | "F"). If None, traces are returned as they are, default: None + return_scaled : bool, default: None If True and the recording has scaling (gain_to_uV and offset_to_uV properties), - traces are scaled to uV, by default False - cast_unsigned : bool, optional + traces are scaled to uV, default: False + cast_unsigned : bool, default: None If True and the traces are unsigned, they are cast to integer and centered - (an offset of (2**nbits) is subtracted), by default False + (an offset of (2**nbits) is subtracted), default: False Returns ------- @@ -337,9 +337,9 @@ def get_time_info(self, segment_index=None) -> dict: dict A dictionary containing the following key-value pairs: - - 'sampling_frequency': The sampling frequency of the RecordingSegment. - - 't_start': The start time of the RecordingSegment. - - 'time_vector': The time vector of the RecordingSegment. + - "sampling_frequency": The sampling frequency of the RecordingSegment. + - "t_start": The start time of the RecordingSegment. + - "time_vector": The time vector of the RecordingSegment. Notes ----- @@ -362,8 +362,8 @@ def get_times(self, segment_index=None): Parameters ---------- - segment_index : int, optional - The segment index (required for multi-segment), by default None + segment_index : int or None, default: None + The segment index (required for multi-segment) Returns ------- @@ -380,8 +380,8 @@ def has_time_vector(self, segment_index=None): Parameters ---------- - segment_index : int, optional - The segment index (required for multi-segment), by default None + segment_index : int or None, default: None + The segment index (required for multi-segment) Returns ------- @@ -400,10 +400,10 @@ def set_times(self, times, segment_index=None, with_warning=True): ---------- times : 1d np.array The time vector - segment_index : int, optional - The segment index (required for multi-segment), by default None - with_warning : bool, optional - If True, a warning is printed, by default True + segment_index : int or None, default: None + The segment index (required for multi-segment) + with_warning : bool, default: True + If True, a warning is printed """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] @@ -437,7 +437,7 @@ def time_to_sample_index(self, time_s, segment_index=None): def _save(self, format="binary", **save_kwargs): """ This function replaces the old CacheRecordingExtractor, but enables more engines - for caching a results. At the moment only 'binary' with memmap is supported. + for caching a results. At the moment only "binary" with memmap is supported. We plan to add other engines, such as zarr and NWB. """ @@ -712,9 +712,9 @@ def get_times_kwargs(self) -> dict: dict A dictionary containing the following key-value pairs: - - 'sampling_frequency': The sampling frequency of the RecordingSegment. - - 't_start': The start time of the RecordingSegment. - - 'time_vector': The time vector of the RecordingSegment. + - "sampling_frequency": The sampling frequency of the RecordingSegment. + - "t_start": The start time of the RecordingSegment. + - "time_vector": The time vector of the RecordingSegment. Notes ----- @@ -770,16 +770,15 @@ def get_traces( Parameters ---------- - start_frame: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. - order: (Order, optional) + start_frame: Union[int, None], default: None + start sample index, or zero if None + end_frame: Union[int, None], default: None + end_sample, or number of samples if None + channel_indices: Union[List, None], default: None + Indices of channels to return, or all channels if None + order: list or None, default: None The memory order of the returned array. - Use Order.C for C order, Order.F for Fortran order, or Order.K to keep the order of the underlying data. - Defaults to Order.K. + Use Order.C for C order, Order.F for Fortran order, or Order.K to keep the order of the underlying data Returns ------- diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index d411f38d2a..5d0d1b130a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -72,7 +72,23 @@ def _frame_slice(self, channel_ids, renamed_channel_ids=None): def set_probe(self, probe, group_mode="by_probe", in_place=False): """ - Wrapper on top on set_probes when there one unique probe. + Attach a list of Probe object to a recording. + + Parameters + ---------- + probe_or_probegroup: Probe, list of Probe, or ProbeGroup + The probe(s) to be attached to the recording + group_mode: "by_probe" | "by_shank", default: "by_probe + "by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe") + or shanks ("by_shanks") + in_place: bool + False by default. + Useful internally when extractor do self.set_probegroup(probe) + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) """ assert isinstance(probe, Probe), "must give Probe" probegroup = ProbeGroup() @@ -84,7 +100,7 @@ def set_probegroup(self, probegroup, group_mode="by_probe", in_place=False): def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False): """ - Attach a Probe to a recording. + Attach a list of Probe objects to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. If some contacts of the Probe are not connected (device_channel_indices=-1) then the recording is "sliced" and only connected channel are kept. @@ -96,9 +112,9 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False) ---------- probe_or_probegroup: Probe, list of Probe, or ProbeGroup The probe(s) to be attached to the recording - group_mode: str - 'by_probe' or 'by_shank'. Adds grouping property to the recording based on the probes ('by_probe') - or shanks ('by_shanks') + group_mode: "by_probe" | "by_shank", default: "by_probe" + "by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe") + or shanks ("by_shank") in_place: bool False by default. Useful internally when extractor do self.set_probegroup(probe) @@ -253,18 +269,18 @@ def _extra_metadata_to_folder(self, folder): def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ - Creates a 'dummy' probe based on locations. + Creates a "dummy" probe based on locations. Parameters ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, optional - Electrode shapes, by default "circle" - shape_params : dict, optional - Shape parameters, by default {"radius": 1} - axes : str, optional - If ndim is 3, indicates the axes that define the plane of the electrodes, by default "xy" + shape : str, default: "circle" + Electrode shapes + shape_params : dict, default: {"radius": 1} + Shape parameters + axes : str, default: "xy" + If ndim is 3, indicates the axes that define the plane of the electrodes Returns ------- @@ -287,18 +303,18 @@ def create_dummy_probe_from_locations(self, locations, shape="circle", shape_par def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ - Sets a 'dummy' probe based on locations. + Sets a "dummy" probe based on locations. Parameters ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, optional - Electrode shapes, by default "circle" - shape_params : dict, optional - Shape parameters, by default {"radius": 1} - axes : str, optional - If ndim is 3, indicates the axes that define the plane of the electrodes, by default "xy" + shape : str, default: default: "circle" + Electrode shapes + shape_params : dict, default: {"radius": 1} + Shape parameters + axes : "xy" | "yz" | "xz", default: "xy" + If ndim is 3, indicates the axes that define the plane of the electrodes """ probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes) self.set_probe(probe, in_place=True) @@ -386,8 +402,8 @@ def planarize(self, axes: str = "xy"): Parameters ---------- - axes : str, optional - The axes to keep, by default "xy" + axes : "xy" | "yz" |"xz", default: "xy" + The axes to keep Returns ------- @@ -412,8 +428,8 @@ def channel_slice(self, channel_ids, renamed_channel_ids=None): ---------- channel_ids : np.array or list The list of channels to keep - renamed_channel_ids : np.array or list, optional - A list of renamed channels, by default None + renamed_channel_ids : np.array or list, default: None + A list of renamed channels Returns ------- @@ -459,7 +475,7 @@ def frame_slice(self, start_frame, end_frame): def select_segments(self, segment_indices): """ - Return a new object with the segments specified by 'segment_indices'. + Return a new object with the segments specified by "segment_indices". Parameters ---------- @@ -475,14 +491,14 @@ def select_segments(self, segment_indices): def split_by(self, property="group", outputs="dict"): """ - Splits object based on a certain property (e.g. 'group') + Splits object based on a certain property (e.g. "group") Parameters ---------- - property : str, optional - The property to use to split the object, by default 'group' - outputs : str, optional - 'dict' or 'list', by default 'dict' + property : str, default: "group" + The property to use to split the object, default: "group" + outputs : "dict" | "list", default: "dict" + Whether to return a dict or a list Returns ------- diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b4e3c11f55..02262fd88e 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,4 +1,6 @@ -from typing import List, Union +from __future__ import annotations + +from typing import Union from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np @@ -17,7 +19,7 @@ class BaseSnippets(BaseRecordingSnippets): _main_features = [] def __init__( - self, sampling_frequency: float, nbefore: Union[int, None], snippet_len: int, channel_ids: List, dtype + self, sampling_frequency: float, nbefore: Union[int, None], snippet_len: int, channel_ids: list, dtype ): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype @@ -25,7 +27,7 @@ def __init__( self._nbefore = nbefore self._snippet_len = snippet_len - self._snippets_segments: List[BaseSnippetsSegment] = [] + self._snippets_segments: list[BaseSnippetsSegment] = [] # initialize main annotation and properties def __repr__(self): @@ -90,7 +92,7 @@ def get_snippets( self, indices=None, segment_index: Union[int, None] = None, - channel_ids: Union[List, None] = None, + channel_ids: Union[list, None] = None, return_scaled=False, ): segment_index = self._check_segment_index(segment_index) @@ -116,7 +118,7 @@ def get_snippets_from_frames( segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, - channel_ids: Union[List, None] = None, + channel_ids: Union[list, None] = None, return_scaled=False, ): segment_index = self._check_segment_index(segment_index) @@ -151,7 +153,7 @@ def _select_segments(self, segment_indices): def _save(self, format="npy", **save_kwargs): """ - At the moment only 'npy' and 'memory' avaiable: + At the moment only "npy" and "memory" avaiable: """ if format == "npy": @@ -220,18 +222,18 @@ def __init__(self): def get_snippets( self, - indices=None, - channel_indices: Union[List, None] = None, + indices, + channel_indices: Union[list, None] = None, ) -> np.ndarray: """ Return the snippets, optionally for a subset of samples and/or channels Parameters ---------- - indexes: (Union[int, None], optional) - indices of the snippets to return, or all if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. + indices: list[int] + Indices of the snippets to return + channel_indices: Union[list, None], default: None + Indices of channels to return, or all channels if None Returns ------- @@ -262,10 +264,10 @@ def frames_to_indices(self, start_frame: Union[int, None] = None, end_frame: Uni Parameters ---------- - start_frame: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. + start_frame: Union[int, None], default: None + start sample index, or zero if None + end_frame: Union[int, None], default: None + end_sample, or number of samples if None Returns ------- diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2a06a699cb..94b08d8cc3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from typing import List, Optional, Union @@ -66,9 +68,9 @@ def get_num_samples(self, segment_index=None): Parameters ---------- - segment_index : int, optional + segment_index : int or None, default: None The segment index to retrieve the number of samples for. - For multi-segment objects, it is required, by default None + For multi-segment objects, it is required Returns ------- @@ -157,9 +159,8 @@ def register_recording(self, recording, check_spike_frames=True): recording : BaseRecording Recording with the same number of segments as current sorting. Assigned to self._recording. - check_spike_frames : bool, optional + check_spike_frames : bool, default: True If True, assert for each segment that all spikes are within the recording's range. - By default True. """ assert np.isclose( self.get_sampling_frequency(), recording.get_sampling_frequency(), atol=0.1 @@ -221,8 +222,8 @@ def _save(self, format="numpy_folder", **save_kwargs): This function replaces the old CachesortingExtractor, but enables more engines for caching a results. - Since v0.98.0 'numpy_folder' is used by defult. - From v0.96.0 to 0.97.0 'npz_folder' was the default. + Since v0.98.0 "numpy_folder" is used by defult. + From v0.96.0 to 0.97.0 "npz_folder" was the default. """ if format == "numpy_folder": @@ -268,7 +269,7 @@ def get_total_num_spikes(self): ) return self.count_num_spikes_per_unit() - def count_num_spikes_per_unit(self): + def count_num_spikes_per_unit(self) -> dict: """ For each unit : get number of spikes across segments. @@ -317,8 +318,8 @@ def select_units(self, unit_ids, renamed_unit_ids=None): ---------- unit_ids : numpy.array or list List of unit ids to keep - renamed_unit_ids : numpy.array or list, optional - If given, the kept unit ids are renamed, by default None + renamed_unit_ids : numpy.array or list, default: None + If given, the kept unit ids are renamed Returns ------- @@ -432,23 +433,22 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac Parameters ---------- - concatenated: bool - With concatenated=True (default) the output is one numpy "spike vector" with spikes from all segments. + concatenated: bool, default: True + With concatenated=True the output is one numpy "spike vector" with spikes from all segments. With concatenated=False the output is a list "spike vector" by segment. - extremum_channel_inds: None or dict - If a dictionnary of unit_id to channel_ind is given then an extra field 'channel_index'. + extremum_channel_inds: None or dict, default: None + If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` - use_cache: bool - When True (default) the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). + use_cache: bool, default: True + When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. Returns ------- spikes: np.array - Structured numpy array ('sample_index', 'unit_index', 'segment_index') with all spikes - Or ('sample_index', 'unit_index', 'segment_index', 'channel_index') if extremum_channel_inds + Structured numpy array ("sample_index", "unit_index", "segment_index") with all spikes + Or ("sample_index", "unit_index", "segment_index", "channel_index") if extremum_channel_inds is given """ @@ -597,8 +597,8 @@ def get_unit_spike_train( Parameters ---------- unit_id - start_frame: int, optional - end_frame: int, optional + start_frame: int, default: None + end_frame: int, default: None Returns ------- diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index d185111b8c..b3ae8c3145 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -13,7 +13,7 @@ class BinaryFolderRecording(BinaryRecordingExtractor): BinaryFolderRecording is an internal format used in spikeinterface. It is a BinaryRecordingExtractor + metadata contained in a folder. - It is created with the function: `recording.save(format='binary', folder='/myfolder')` + It is created with the function: `recording.save(format="binary", folder="/myfolder")` Parameters ---------- diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index b45290caa5..d8c6512a38 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -26,19 +26,19 @@ class BinaryRecordingExtractor(BaseRecording): Number of channels dtype: str or dtype The dtype of the binary file - time_axis: int - The axis of the time dimension (default 0: F order) - t_starts: None or list of float + time_axis: int, default: 0 + The axis of the time dimension + t_starts: None or list of float, default: None Times in seconds of the first sample for each segment - channel_ids: list (optional) + channel_ids: list, default: None A list of channel ids - file_offset: int (optional) + file_offset: int, default: 0 Number of bytes in the file to offset by during memmap instantiation. - gain_to_uV: float or array-like (optional) + gain_to_uV: float or array-like, default: None The gain to apply to the traces - offset_to_uV: float or array-like + offset_to_uV: float or array-like, default: None The offset to apply to the traces - is_filtered: bool or None + is_filtered: bool or None, default: None If True, the recording is assumed to be filtered. If None, is_filtered is not set. Notes @@ -140,8 +140,8 @@ def write_recording(recording, file_paths, dtype=None, **job_kwargs): The recording extractor object to be saved in .dat format file_paths: str The path to the file. - dtype: dtype - Type of the saved data. Default float32. + dtype: dtype, default: None + Type of the saved data {} """ write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 3a21e356a6..9987edadc6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -1,4 +1,5 @@ -from typing import List, Union +from __future__ import annotations +from typing import Union import numpy as np @@ -87,7 +88,7 @@ def get_traces( self, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + channel_indices: Union[list, None] = None, ) -> np.ndarray: parent_indices = self._parent_channel_indices[channel_indices] traces = self._parent_recording_segment.get_traces(start_frame, end_frame, parent_indices) @@ -181,20 +182,18 @@ def get_frames(self, indices=None): def get_snippets( self, - indices, - channel_indices: Union[List, None] = None, + indices: list[int], + channel_indices: Union[list, None] = None, ) -> np.ndarray: """ Return the snippets, optionally for a subset of samples and/or channels Parameters ---------- - indexes: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. + indices: list[int] + Indices of the snippets to return + channel_indices: Union[List, None], default: None + Indices of channels to return, or all channels if None Returns ------- diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 106a794f6e..2d387da239 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -172,10 +172,10 @@ def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): Number of channels dtype: dtype dtype of the file - time_axis: 0 (default) or 1 + time_axis: 0 or 1, default: 0 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. - offset: int + offset: int, default: 0 number of offset bytes """ @@ -243,7 +243,7 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): def write_binary_recording( recording, - file_paths=None, + file_paths, dtype=None, add_file_extension=True, byte_offset=0, @@ -261,19 +261,17 @@ def write_binary_recording( ---------- recording: RecordingExtractor The recording extractor object to be saved in .dat format - file_path: str + file_path: str or list[str] The path to the file. - dtype: dtype - Type of the saved data. Default float32. - add_file_extension: bool - If True (default), file the '.raw' file extension is added if the file name is not a 'raw', 'bin', or 'dat' - byte_offset: int - Offset in bytes (default 0) to for the binary file (e.g. to write a header) - auto_cast_uint: bool - If True (default), unsigned integers are automatically cast to int if the specified dtype is signed + dtype: dtype or None, default: None + Type of the saved data + If True, file the ".raw" file extension is added if the file name is not a "raw", "bin", or "dat" + byte_offset: int, default: 0 + Offset in bytes to for the binary file (e.g. to write a header) + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed {} """ - assert file_paths is not None, "Provide 'file_path'" job_kwargs = fix_job_kwargs(job_kwargs) file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths @@ -430,12 +428,12 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint= ---------- recording: RecordingExtractor The recording extractor object to be saved in .dat format - dtype: dtype - Type of the saved data. Default float32. - verbose: bool + dtype: dtype, default: None + Type of the saved data + verbose: bool, default: False If True, output is verbose (when chunks are used) - auto_cast_uint: bool - If True (default), unsigned integers are automatically cast to int if the specified dtype is signed + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed {} Returns @@ -511,33 +509,33 @@ def write_to_h5_dataset_format( recording: RecordingExtractor The recording extractor object to be saved in .dat format dataset_path: str - Path to dataset in h5 file (e.g. '/dataset') + Path to dataset in h5 file (e.g. "/dataset") segment_index: int index of segment - save_path: str + save_path: str, default: None The path to the file. - file_handle: file handle + file_handle: file handle, default: None The file handle to dump data. This can be used to append data to an header. In case file_handle is given, the file is NOT closed after writing the binary data. - time_axis: 0 (default) or 1 + time_axis: 0 or 1, default: 0 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. - single_axis: bool, default False - If True, a single-channel recording is saved as a one dimensional array. - dtype: dtype - Type of the saved data. Default float32. - chunk_size: None or int + single_axis: bool, default: False + If True, a single-channel recording is saved as a one dimensional array + dtype: dtype, default: None + Type of the saved data + chunk_size: None or int, default: None Number of chunks to save the file in. This avoid to much memory consumption for big files. - If None and 'chunk_memory' is given, the file is saved in chunks of 'chunk_memory' MB (default 500MB) - chunk_memory: None or str - Chunk size in bytes must endswith 'k', 'M' or 'G' (default '500M') - verbose: bool + If None and "chunk_memory" is given, the file is saved in chunks of "chunk_memory" MB + chunk_memory: None or str, default: "500M" + Chunk size in bytes must endswith "k", "M" or "G" + verbose: bool, default: False If True, output is verbose (when chunks are used) - auto_cast_uint: bool - If True (default), unsigned integers are automatically cast to int if the specified dtype is signed - return_scaled : bool, optional + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed + return_scaled : bool, default: False If True and the recording has scaling (gain_to_uV and offset_to_uV properties), - traces are dumped to uV, by default False + traces are dumped to uV """ import h5py @@ -655,18 +653,18 @@ def write_traces_to_zarr( Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. dataset_paths: list List of paths to traces datasets in the zarr group - channel_chunk_size: int or None - Channels per chunk. Default None (chunking in time only) - dtype: dtype - Type of the saved data. Default float32. - compressor: zarr compressor or None + channel_chunk_size: int or None, default: None (chunking in time only) + Channels per chunk + dtype: dtype, default: None + Type of the saved data + compressor: zarr compressor or None, default: None Zarr compressor - filters: list + filters: list, default: None List of zarr filters - verbose: bool + verbose: bool, default: False If True, output is verbose (when chunks are used) - auto_cast_uint: bool - If True (default), unsigned integers are automatically cast to int if the specified dtype is signed + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed {} """ assert dataset_paths is not None, "Provide 'file_path'" @@ -804,10 +802,10 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: Extractor dictionary func : function Function to apply to the path. It must take a path as input and return a path - target : str, optional - String to match to dictionary key, by default 'path' - copy : bool, optional - If True the original dictionary is deep copied, by default True (at first call) + target : str, default: "path" + String to match to dictionary key + copy : bool, default: True (at first call) + If True the original dictionary is deep copied Returns ------- @@ -870,7 +868,7 @@ def convert_seconds_to_str(seconds: float, long_notation: bool = True) -> str: ---------- seconds : float The duration in seconds. - long_notation : bool, optional, default: True + long_notation : bool, default: True Whether to display the time with additional units (such as milliseconds, minutes, hours, or days). If set to True, the function will display a more detailed representation of the duration, including other units alongside the primary diff --git a/src/spikeinterface/core/datasets.py b/src/spikeinterface/core/datasets.py index e3b6d7b22d..7e03492d1a 100644 --- a/src/spikeinterface/core/datasets.py +++ b/src/spikeinterface/core/datasets.py @@ -20,19 +20,17 @@ def download_dataset( Parameters ---------- - repo : str, optional - The repository to download the dataset from, - defaults to: 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' - remote_path : str + repo : str, default: "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" + The repository to download the dataset from + remote_path : str, default: "mearec/mearec_test_10s.h5" A specific subdirectory in the repository to download (e.g. Mearec, SpikeGLX, etc) - defaults to: "mearec/mearec_test_10s.h5" - local_folder : str, optional + local_folder : str, default: None The destination folder / directory to download the dataset to. defaults to the path "get_global_dataset_folder()" / f{repo_name} (see `spikeinterface.core.globals`) - update_if_exists : bool, optional - Forces re-download of the dataset if it already exists, by default False - unlock : bool, optional - Use to enable the edition of the downloaded file content, by default False + update_if_exists : bool, default: False + Forces re-download of the dataset if it already exists, default: False + unlock : bool, default: False + Use to enable the edition of the downloaded file content, default: False Returns ------- diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index ed1391b0e2..5cd7adab34 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -20,15 +20,15 @@ class FrameSliceSorting(BaseSorting): Parameters ---------- parent_sorting: BaseSorting - start_frame: None or int + start_frame: None or int, default: None Earliest included frame in the parent sorting(/recording). Spike times(/traces) are re-referenced to start_frame in the - sliced objects. Set to 0 by default. - end_frame: None or int + sliced objects. Set to 0 if None. + end_frame: None or int, default: None Latest frame in the parent sorting(/recording). As for usual python slicing, the end frame is excluded (such that the max spike frame in the sliced sorting is `end_frame - start_frame - 1`) - If None (default), the end_frame is either: + If None, the end_frame is either: - The total number of samples, if a recording is assigned - The maximum spike frame + 1, if no recording is assigned """ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 44ea02d32c..c670474f0e 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -40,19 +40,19 @@ def generate_recording( Parameters ---------- - num_channels : int, default 2 + num_channels : int, default: 2 The number of channels in the recording. - sampling_frequency : float, default 30000. (in Hz) - The sampling frequency of the recording, by default 30000. - durations: List[float], default [5.0, 2.5] - The duration in seconds of each segment in the recording, by default [5.0, 2.5]. + sampling_frequency : float, default: 30000. (in Hz) + The sampling frequency of the recording, default: 30000. + durations: List[float], default: [5.0, 2.5] + The duration in seconds of each segment in the recording, default: [5.0, 2.5]. Note that the number of segments is determined by the length of this list. - set_probe: bool, default True - ndim : int, default 2 - The number of dimensions of the probe, by default 2. Set to 3 to make 3 dimensional probes. + set_probe: bool, default: True + ndim : int, default: 2 + The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probes. seed : Optional[int] A seed for the np.ramdom.default_rng function - mode: str ["lazy", "legacy"] Default "lazy". + mode: str ["lazy", "legacy"], default: "lazy". "legacy": generate a NumpyRecording with white noise. This mode is kept for backward compatibility and will be deprecated version 0.100.0. "lazy": return a NoiseGeneratorRecording instance. @@ -172,7 +172,7 @@ def generate_sorting( duration=durations[segment_index], refractory_period_ms=refractory_period_ms, firing_rates=firing_rates, - seed=seed, + seed=seed + segment_index, ) if empty_units is not None: @@ -342,9 +342,9 @@ def synthesize_random_firings( firing_rates: float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. - add_shift_shuffle: bool, default False + add_shift_shuffle: bool, default: False Optionaly add a small shuffle on half spike to autocorrelogram - seed: int, optional + seed: int, default: None seed for the generator Returns @@ -543,11 +543,11 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol duration : float Length of simulated recording (in seconds). baseline_rate : float - Firing rate for 'true' spikes. + Firing rate for "true" spikes. num_violations : int Number of contaminating spikes. - violation_delta : float, optional - Temporal offset of contaminating spikes (in seconds), by default 1e-5. + violation_delta : float, default: 1e-5 + Temporal offset of contaminating spikes (in seconds) Returns ------- @@ -586,11 +586,11 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : List[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_level: float, default 1: + noise_level: float, default: 1 Std of the white noise - dtype : Optional[Union[np.dtype, str]], default='float32' + dtype : Optional[Union[np.dtype, str]], default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. - seed : Optional[int], default=None + seed : Optional[int], default: None The seed for np.random.default_rng. strategy : "tile_pregenerated" or "on_the_fly" The strategy of generating noise chunk: @@ -763,8 +763,9 @@ def generate_recording_by_size( The size in gigabytes (GiB) of the recording. num_channels: int Number of channels. - seed : int, optional - The seed for np.random.default_rng, by default None + seed : int, default: None + The seed for np.random.default_rng + Returns ------- GeneratorRecording @@ -921,7 +922,7 @@ def generate_templates( Cut out in ms after spike peak. seed: int or None A seed for random. - dtype: numpy.dtype, default "float32" + dtype: numpy.dtype, default: "float32" Templates dtype upsample_factor: None or int If not None then template are generated upsampled by this factor. @@ -931,14 +932,14 @@ def generate_templates( An optional dict containing parameters per units. Keys are parameter names: - * 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000)) - * 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14)) - * 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8)) - * 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5)) - * 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) - * 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07)) - * 'decay_power': the decay power (default range: (1.2-1.8)) - * 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). + * "alpha": amplitude of the action potential in a.u. (default range: (6'000-9'000)) + * "depolarization_ms": the depolarization interval in ms (default range: (0.09-0.14)) + * "repolarization_ms": the repolarization interval in ms (default range: (0.5-0.8)) + * "recovery_ms": the recovery interval in ms (default range: (1.0-1.5)) + * "positive_amplitude": the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) + * "smooth_ms": the gaussian smooth in ms (default range: (0.03-0.07)) + * "decay_power": the decay power (default range: (1.2-1.8)) + * "propagation_speed": mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values contains vector with same size of num_units. If the key is not in dict then it is generated using unit_params_range unit_params_range: dict of tuple @@ -1068,10 +1069,10 @@ class InjectTemplatesRecording(BaseRecording): Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. - nbefore: list[int] | int | None + nbefore: list[int] | int | None, default: None Where is the center of the template for each unit? If None, will default to the highest peak. - amplitude_factor: list[float] | float | None, default None + amplitude_factor: list[float] | float | None, default: None The amplitude of each spike for each unit. Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). @@ -1082,7 +1083,7 @@ class InjectTemplatesRecording(BaseRecording): num_samples: list[int] | int | None The number of samples in the recording per segment. You can use int for mono-segment objects. - upsample_vector: np.array or None, default None. + upsample_vector: np.array or None, default: None. When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3] @@ -1376,13 +1377,13 @@ def generate_ground_truth_recording( Parameters ---------- - durations: list of float, default [10.] + durations: list of float, default: [10.] Durations in seconds for all segments. - sampling_frequency: float, default 25000 + sampling_frequency: float, default: 25000 Sampling frequency. - num_channels: int, default 4 + num_channels: int, default: 4 Number of channels, not used when probe is given. - num_units: int, default 10. + num_units: int, default: 10 Number of units, not used when sorting is given. sorting: Sorting or None An external sorting object. If not provide, one is genrated. @@ -1396,11 +1397,11 @@ def generate_ground_truth_recording( Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. - ms_before: float, default 1.5 + ms_before: float, default: 1.5 Cut out in ms before spike peak. - ms_after: float, default 3. + ms_after: float, default: 3 Cut out in ms after spike peak. - upsample_factor: None or int, default None + upsample_factor: None or int, default: None A upsampling factor used only when templates are not provided. upsample_vector: np.array or None Optional the upsample_vector can given. This has the same shape as spike_vector @@ -1412,7 +1413,7 @@ def generate_ground_truth_recording( Dict used to generated template when template not provided. generate_templates_kwargs: dict Dict used to generated template when template not provided. - dtype: np.dtype, default "float32" + dtype: np.dtype, default: "float32" The dtype of the recording. seed: int or None Seed for random initialization. diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index d039206296..aea11c90be 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -1,7 +1,7 @@ """ -'global_tmp_folder' is a variable that is generated or can be set manually. +"global_tmp_folder" is a variable that is generated or can be set manually. -It is useful when we do extractor.save(name='name'). +It is useful when we do extractor.save(name="name"). """ import tempfile diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index cf7a67489c..a9df4f4626 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -21,11 +21,11 @@ - chunk_size: int Number of samples per chunk - chunk_memory: str - Memory usage for each job (e.g. '100M', '1G') + Memory usage for each job (e.g. "100M", "1G") - total_memory: str - Total memory usage (e.g. '500M', '2G') + Total memory usage (e.g. "500M", "2G") - chunk_duration : str or float or None - Chunk duration in s if float or with units if str (e.g. '1s', '500ms') + Chunk duration in s if float or with units if str (e.g. "1s", "500ms") * n_jobs: int Number of jobs to use. With -1 the number of jobs is the same as number of cores * progress_bar: bool @@ -181,7 +181,7 @@ def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs ): """ - 'chunk_size' is the traces.shape[0] for each worker. + "chunk_size" is the traces.shape[0] for each worker. Flexible chunk_size setter with 3 ways: * "chunk_size": is the length in sample for each chunk independently of channel count and dtype. @@ -196,12 +196,12 @@ def ensure_chunk_size( chunk_size: int or None size for one chunk per job chunk_memory: str or None - must endswith 'k', 'M' or 'G' + must end with "k", "M" or "G" total_memory: str or None - must endswith 'k', 'M' or 'G' + must end with "k", "M" or "G" chunk_duration: None or float or str Units are second if float. - If str then the str must contain units(e.g. '1s', '500ms') + If str then the str must contain units(e.g. "1s", "500ms") """ if chunk_size is not None: # manual setting @@ -255,7 +255,7 @@ class ChunkRecordingExecutor: * at once if chunk_size is None (high RAM usage) * in parallel with ProcessPoolExecutor (higher speed) - The initializer ('init_func') allows to set a global context to avoid heavy serialization + The initializer ("init_func") allows to set a global context to avoid heavy serialization (for examples, see implementation in `core.WaveformExtractor`). Parameters @@ -265,43 +265,43 @@ class ChunkRecordingExecutor: func: function Function that runs on each chunk init_func: function - Initializer function to set the global context (accessible by 'func') + Initializer function to set the global context (accessible by "func") init_args: tuple Arguments for init_func verbose: bool If True, output is verbose - progress_bar: bool - If True, a progress bar is printed to monitor the progress of the process - handle_returns: bool + job_name: str, default: "" + Job name + handle_returns: bool, default: False If True, the function can return values - gather_func: None or callable + gather_func: None or callable, default: None Optional function that is called in the main thread and retrieves the results of each worker. This function can be used instead of `handle_returns` to implement custom storage on-the-fly. - n_jobs: int - Number of jobs to be used (default 1). Use -1 to use as many jobs as number of cores - total_memory: str + n_jobs: int, default: 1 + Number of jobs to be used. Use -1 to use as many jobs as number of cores + total_memory: str, default: None Total memory (RAM) to use (e.g. "1G", "500M") - chunk_memory: str + chunk_memory: str, default: None Memory per chunk (RAM) to use (e.g. "1G", "500M") - chunk_size: int or None - Size of each chunk in number of samples. If 'total_memory' or 'chunk_memory' are used, it is ignored. + chunk_size: int or None, default: None + Size of each chunk in number of samples. If "total_memory" or "chunk_memory" are used, it is ignored. chunk_duration : str or float or None - Chunk duration in s if float or with units if str (e.g. '1s', '500ms') - mp_context : str or None - "fork" (default) or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). + Chunk duration in s if float or with units if str (e.g. "1s", "500ms") + mp_context : str or None, default: None + "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). "fork" is only available on UNIX systems. - job_name: str - Job name - max_threads_per_process: int or None + max_threads_per_process: int or None, default: None Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. + progress_bar: bool, default: False + If True, a progress bar is printed to monitor the progress of the process Returns ------- res: list - If 'handle_returns' is True, the results for each chunk process + If "handle_returns" is True, the results for each chunk process """ def __init__( diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index a6dabf77b5..a00df98e05 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -55,10 +55,10 @@ def __init__( ---------- recording : BaseRecording The recording object. - parents : Optional[List[PipelineNode]], optional - Pass parents nodes to perform a previous computation, by default None - return_output : bool or tuple of bool - Whether or not the output of the node is returned by the pipeline, by default False + parents : Optional[List[PipelineNode]], default: None + Pass parents nodes to perform a previous computation + return_output : bool or tuple of bool, default: True + Whether or not the output of the node is returned by the pipeline. When a Node have several toutputs then this can be a tuple of bool. @@ -154,10 +154,10 @@ class SpikeRetriever(PeakSource): If False, the max channel is computed for each spike given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. - radius_um: float (default 50.) + radius_um: float, default: 50 The radius to find the real max channel. Used only when channel_from_template=False - peak_sign: str (default "neg") + peak_sign: str, default: "neg" Peak sign to find the max channel. Used only when channel_from_template=False """ @@ -256,14 +256,14 @@ def __init__( ---------- recording : BaseRecording The recording object. - parents : Optional[List[PipelineNode]], optional - Pass parents nodes to perform a previous computation, by default None - return_output : bool, optional - Whether or not the output of the node is returned by the pipeline, by default False - ms_before : float, optional - The number of milliseconds to include before the peak of the spike, by default 1. - ms_after : float, optional - The number of milliseconds to include after the peak of the spike, by default 1. + ms_before : float + The number of milliseconds to include before the peak of the spike + ms_after : float + The number of milliseconds to include after the peak of the spike + parents : Optional[List[PipelineNode]], default: None + Pass parents nodes to perform a previous computation + return_output : bool, default: False + Whether or not the output of the node is returned by the pipeline """ PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output) @@ -291,14 +291,15 @@ def __init__( ---------- recording : BaseRecording The recording object. - parents : Optional[List[PipelineNode]], optional - Pass parents nodes to perform a previous computation, by default None - return_output : bool, optional - Whether or not the output of the node is returned by the pipeline, by default False - ms_before : float, optional - The number of milliseconds to include before the peak of the spike, by default 1. - ms_after : float, optional - The number of milliseconds to include after the peak of the spike, by default 1. + ms_before : float + The number of milliseconds to include before the peak of the spike + ms_after : float + The number of milliseconds to include after the peak of the spike + parents : Optional[List[PipelineNode]], default: None + Pass parents nodes to perform a previous computation + return_output : bool, default: False + Whether or not the output of the node is returned by the pipeline + """ WaveformsNode.__init__( @@ -344,17 +345,15 @@ def __init__( Parameters ---------- recording : BaseRecording - The recording object. - parents : Optional[List[PipelineNode]], optional - Pass parents nodes to perform a previous computation, by default None - return_output : bool, optional - Whether or not the output of the node is returned by the pipeline, by default False - ms_before : float, optional - The number of milliseconds to include before the peak of the spike, by default 1. - ms_after : float, optional - The number of milliseconds to include after the peak of the spike, by default 1. - - + The recording object + ms_before : float + The number of milliseconds to include before the peak of the spike + ms_after : float + The number of milliseconds to include after the peak of the spike + parents : Optional[List[PipelineNode]], default: None + Pass parents nodes to perform a previous computation + return_output : bool, default: False + Whether or not the output of the node is returned by the pipeline """ WaveformsNode.__init__( self, diff --git a/src/spikeinterface/core/npyfoldersnippets.py b/src/spikeinterface/core/npyfoldersnippets.py index c002bbe044..271a8f4f12 100644 --- a/src/spikeinterface/core/npyfoldersnippets.py +++ b/src/spikeinterface/core/npyfoldersnippets.py @@ -13,7 +13,7 @@ class NpyFolderSnippets(NpySnippetsExtractor): NpyFolderSnippets is an internal format used in spikeinterface. It is a NpySnippetsExtractor + metadata contained in a folder. - It is created with the function: `snippets.save(format='npy', folder='/myfolder')` + It is created with the function: `snippets.save(format="npy", folder="/myfolder")` Parameters ---------- diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 69c48356e5..40fbfac4d3 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -109,12 +109,10 @@ def get_snippets( Parameters ---------- - indexes: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. + indices: list[int] + Indices of the snippets to return, or all if None + channel_indices: Union[List, None], default: None + Indices of channels to return, or all channels if None Returns ------- @@ -134,10 +132,10 @@ def frames_to_indices(self, start_frame: Union[int, None] = None, end_frame: Uni Parameters ---------- - start_frame: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. + start_frame: Union[int, None], default: None + start sample index, or zero if None + end_frame: Union[int, None], default: None + end_sample, or number of samples if None Returns ------- diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 3d7ec6cd1a..82075e638c 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from spikeinterface.core import ( BaseRecording, @@ -14,7 +16,7 @@ from multiprocessing.shared_memory import SharedMemory -from typing import List, Union +from typing import Union class NumpyRecording(BaseRecording): @@ -169,10 +171,11 @@ def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None Parameters ---------- times_list: list of array (or array) - An array of spike times (in frames). + An array of spike times (in frames) labels_list: list of array (or array) - An array of spike labels corresponding to the given times. - unit_ids: (None by default) the explicit list of unit_ids that should be extracted from labels_list + An array of spike labels corresponding to the given times + unit_ids: list or None, default: None + The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list) """ @@ -547,19 +550,17 @@ def __init__(self, snippets, spikesframes): def get_snippets( self, indices, - channel_indices: Union[List, None] = None, + channel_indices: Union[list, None] = None, ) -> np.ndarray: """ Return the snippets, optionally for a subset of samples and/or channels Parameters ---------- - indexes: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. + indices: list[int] + Indices of the snippets to return + channel_indices: Union[list, None], default: None + Indices of channels to return, or all channels if None Returns ------- @@ -579,11 +580,10 @@ def frames_to_indices(self, start_frame: Union[int, None] = None, end_frame: Uni Parameters ---------- - start_frame: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - + start_frame: Union[int, None], default: None + start sample index, or zero if None + end_frame: Union[int, None], default: None + end_sample, or number of samples if None Returns ------- snippets: slice diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index ff9cd99389..34313cd7ae 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -23,16 +23,19 @@ def get_random_data_chunks( ---------- recording: BaseRecording The recording to get random chunks from - return_scaled: bool + return_scaled: bool, default: False If True, returned chunks are scaled to uV - num_chunks_per_segment: int + num_chunks_per_segment: int, default: 20 Number of chunks per segment - chunk_size: int + chunk_size: int, default: 10000 Size of a chunk in number of frames - concatenated: bool (default True) - If True chunk are concatenated along time axis. - seed: int + concatenated: bool, default: True + If True chunk are concatenated along time axis + seed: int, default: 0 Random seed + margin_frames: int, default: 0 + Margin in number of frames to avoid edge effects + Returns ------- chunk_list: np.array @@ -98,7 +101,7 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): The recording extractor to get closest channels channel_ids: list List of channels ids to compute there near neighborhood - num_channels: int, optional + num_channels: int, default: None Maximum number of neighborhood channels to return Returns @@ -135,7 +138,7 @@ def get_noise_levels( ): """ Estimate noise for each channel using MAD methods. - You can use standard deviation with `method='std'` + You can use standard deviation with `method="std"` Internally it samples some chunk across segment. And then, it use MAD estimator (more robust than STD) @@ -147,8 +150,8 @@ def get_noise_levels( The recording extractor to get noise levels return_scaled: bool If True, returned noise levels are scaled to uV - method: str - 'mad' or 'std' + method: "mad" | "std", default: "mad" + The method to use to estimate noise levels force_recompute: bool If True, noise levels are recomputed even if they are already stored in the recording extractor random_chunk_kwargs: dict @@ -312,10 +315,10 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), The input recording channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str, tuple, or list + dimensions : str, tuple, or list, default: ('x', 'y') If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. - This approach is recommended since there is less ambiguity, by default ('x', 'y') + This approach is recommended since there is less ambiguity flip: bool, default: False If flip is False then the order is bottom first (starting from tip of the probe). If flip is True then the order is upper first. diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 85e36cf7a5..614dd0b295 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -41,8 +41,8 @@ class AppendSegmentRecording(BaseRecording): ---------- recording_list : list of BaseRecording A list of recordings - sampling_frequency_max_diff : float - Maximum allowed difference of sampling frequencies across recordings (default 0) + sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across recordings """ def __init__(self, recording_list, sampling_frequency_max_diff=0): @@ -106,10 +106,10 @@ class ConcatenateSegmentRecording(BaseRecording): ---------- recording_list : list of BaseRecording A list of recordings - ignore_times: bool - If True (default), time information (t_start, time_vector) is ignored when concatenating recordings. - sampling_frequency_max_diff : float - Maximum allowed difference of sampling frequencies across recordings (default 0) + ignore_times: bool, default: True + If True, time information (t_start, time_vector) is ignored when concatenating recordings + sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across recordings """ def __init__(self, recording_list, ignore_times=True, sampling_frequency_max_diff=0): @@ -284,8 +284,8 @@ class AppendSegmentSorting(BaseSorting): ---------- sorting_list : list of BaseSorting A list of sortings - sampling_frequency_max_diff : float - Maximum allowed difference of sampling frequencies across sortings (default 0) + sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across sortings """ def __init__(self, sorting_list, sampling_frequency_max_diff=0): @@ -345,15 +345,15 @@ class ConcatenateSegmentSorting(BaseSorting): A list of sortings. If `total_samples_list` is not provided, all sortings should have an assigned recording. Otherwise, all sortings should be monosegments. - total_samples_list : list[int] or None + total_samples_list : list[int] or None, default: None If the sortings have no assigned recording, the total number of samples of each of the concatenated (monosegment) sortings is pulled from this list. - ignore_times : bool - If True (default), time information (t_start, time_vector) is ignored + ignore_times : bool, default: True + If True, time information (t_start, time_vector) is ignored when concatenating the sortings' assigned recordings. - sampling_frequency_max_diff : float - Maximum allowed difference of sampling frequencies across sortings (default 0) + sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across sortings """ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sampling_frequency_max_diff=0): @@ -523,12 +523,12 @@ class SplitSegmentSorting(BaseSorting): ---------- parent_sorting : BaseSorting Sorting with a single segment (e.g. from sorting concatenated recording) - recording_or_recording_list : list of recordings, ConcatenateSegmentRecording, or None + recording_or_recording_list : list of recordings, ConcatenateSegmentRecording, or None, default: None If list of recordings, uses the lengths of those recordings to split the sorting into smaller segments If ConcatenateSegmentRecording, uses the associated list of recordings to split the sorting into smaller segments - If None, looks for the recording associated with the sorting (default None) + If None, looks for the recording associated with the sorting """ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None): diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index c813c26442..c86eb0d796 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -19,7 +19,7 @@ class NumpyFolderSorting(BaseSorting): * a "numpysorting_info.json" containing sampling_frequency, unit_ids and num_segments * a metadata folder for units properties. - It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")` + It is created with the function: `sorting.save(folder="/myfolder", format="numpy_folder")` """ @@ -80,7 +80,7 @@ class NpzFolderSorting(NpzSortingExtractor): * "npz.json" which the json description of NpzSortingExtractor * a metadata folder for units properties. - It is created with the function: `sorting.save(folder='/myfolder', format="npz_folder")` + It is created with the function: `sorting.save(folder="/myfolder", format="npz_folder")` Parameters ---------- diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 896e3800d7..07a57f7807 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -7,29 +7,29 @@ _sparsity_doc = """ method: str - * "best_channels": N best channels with the largest amplitude. Use the 'num_channels' argument to specify the + * "best_channels": N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. - * "radius": radius around the best channel. Use the 'radius_um' argument to specify the radius in um - * "snr": threshold based on template signal-to-noise ratio. Use the 'threshold' argument + * "radius": radius around the best channel. Use the "radius_um" argument to specify the radius in um + * "snr": threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) - * "ptp": threshold based on the peak-to-peak values on every channels. Use the 'threshold' argument + * "ptp": threshold based on the peak-to-peak values on every channels. Use the "threshold" argument to specify the ptp threshold (in units of noise levels) * "energy": threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the 'threshold' argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) - * "by_property": sparsity is given by a property of the recording and sorting(e.g. 'group'). - Use the 'by_property' argument to specify the property name. + * "by_property": sparsity is given by a property of the recording and sorting(e.g. "group"). + Use the "by_property" argument to specify the property name. peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + Sign of the template to compute best channels ("neg", "pos", "both") num_channels: int - Number of channels for 'best_channels' method + Number of channels for "best_channels" method radius_um: float - Radius in um for 'radius' method + Radius in um for "radius" method threshold: float - Threshold in SNR 'threshold' method + Threshold in SNR "threshold" method by_property: object - Property name for 'by_property' method + Property name for "by_property" method """ @@ -71,19 +71,19 @@ class ChannelSparsity: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(we, num_channels, peak_sign='neg') + >>> sparsity = ChannelSparsity.from_best_channels(we, num_channels, peak_sign="neg") Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(we, radius_um, peak_sign='neg') + >>> sparsity = ChannelSparsity.from_radius(we, radius_um, peak_sign="neg") Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(we, threshold, peak_sign='neg') + >>> sparsity = ChannelSparsity.from_snr(we, threshold, peak_sign="neg") Using a template energy threshold: >>> sparsity = ChannelSparsity.from_energy(we, threshold) - Using a recording/sorting property (e.g. 'group'): + Using a recording/sorting property (e.g. "group"): >>> sparsity = ChannelSparsity.from_property(we, by_property="group") @@ -251,7 +251,7 @@ def from_dict(cls, dictionary: dict): def from_best_channels(cls, we, num_channels, peak_sign="neg"): """ Construct sparsity from N best channels with the largest amplitude. - Use the 'num_channels' argument to specify the number of channels. + Use the "num_channels" argument to specify the number of channels. """ from .template_tools import get_template_amplitudes @@ -267,7 +267,7 @@ def from_best_channels(cls, we, num_channels, peak_sign="neg"): def from_radius(cls, we, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. - Use the 'radius_um' argument to specify the radius in um + Use the "radius_um" argument to specify the radius in um """ from .template_tools import get_template_extremum_channel @@ -285,7 +285,7 @@ def from_radius(cls, we, radius_um, peak_sign="neg"): def from_snr(cls, we, threshold, peak_sign="neg"): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. - Use the 'threshold' argument to specify the SNR threshold. + Use the "threshold" argument to specify the SNR threshold. """ from .template_tools import get_template_amplitudes @@ -302,7 +302,7 @@ def from_snr(cls, we, threshold, peak_sign="neg"): def from_ptp(cls, we, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. - Use the 'threshold' argument to specify the SNR threshold. + Use the "threshold" argument to specify the SNR threshold. """ mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") @@ -317,7 +317,7 @@ def from_ptp(cls, we, threshold): def from_energy(cls, we, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. - Use the 'threshold' argument to specify the SNR threshold. + Use the "threshold" argument to specify the SNR threshold. """ mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") noise = np.sqrt(we.nsamples) * get_noise_levels(we.recording, return_scaled=we.return_scaled) @@ -331,8 +331,8 @@ def from_energy(cls, we, threshold): @classmethod def from_property(cls, we, by_property): """ - Construct sparsity witha property of the recording and sorting(e.g. 'group'). - Use the 'by_property' argument to specify the property name. + Construct sparsity witha property of the recording and sorting(e.g. "group"). + Use the "by_property" argument to specify the property name. """ # check consistency assert by_property in we.recording.get_property_keys(), f"Property {by_property} is not a recording property" diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index b6022e27c0..a6de2de2fa 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -19,8 +19,8 @@ def get_template_amplitudes( peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" - 'extremum': max or min - 'at_index': take value at spike index + "extremum": max or min + "at_index": take value at spike index Returns ------- @@ -75,16 +75,16 @@ def get_template_extremum_channel( peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" - 'extremum': max or min - 'at_index': take value at spike index + "extremum": max or min + "at_index": take value at spike index outputs: "id" | "index", default: "id" - * 'id': channel id - * 'index': channel index + * "id": channel id + * "index": channel index Returns ------- extremum_channels: dict - Dictionary with unit ids as keys and extremum channels (id or index based on 'outputs') + Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ assert peak_sign in ("both", "neg", "pos") @@ -127,13 +127,13 @@ def get_template_channel_sparsity( The waveform extractor {} outputs: str - * 'id': channel id - * 'index': channel index + * "id": channel id + * "index": channel index Returns ------- sparsity: dict - Dictionary with unit ids as keys and sparse channel ids or indices (id or index based on 'outputs') + Dictionary with unit ids as keys and sparse channel ids or indices (id or index based on "outputs") as values """ from spikeinterface.core.sparsity import compute_sparsity @@ -223,8 +223,8 @@ def get_template_extremum_amplitude( Sign of the template to compute best channels mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed - 'extremum': max or min - 'at_index': take value at spike index + "extremum": max or min + "at_index": take value at spike index Returns ------- diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 38987a58e5..4326cd15aa 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -3,7 +3,7 @@ but check only for BaseRecording general methods. """ import json -import shutil +import pickle from pathlib import Path import pytest import numpy as np @@ -111,7 +111,7 @@ def test_BaseRecording(): rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder) rec3 = load_extractor(d, base_folder=cache_folder) - # dump/load json + # dump/load json - relative to rec.dump_to_json(cache_folder / "test_BaseRecording_rel.json", relative_to=cache_folder) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) @@ -128,6 +128,23 @@ def test_BaseRecording(): "/" not in data["kwargs"]["file_paths"][0] ) # Relative to parent folder, so there shouldn't be any '/' in the path. + # dump/load pkl - relative to + rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel.pkl", relative_to=cache_folder) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) + + # dump/load relative=True + rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel_true.pkl", relative_to=True) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) + check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) + check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) + with open(cache_folder / "test_BaseRecording_rel_true.pkl", "rb") as pkl_file: + data = pickle.load(pkl_file) + assert ( + "/" not in data["kwargs"]["file_paths"][0] + ) # Relative to parent folder, so there shouldn't be any '/' in the path. + # cache to binary folder = cache_folder / "simple_recording" rec.save(format="binary", folder=folder) diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index 08bb22a2c8..565a743084 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -4,7 +4,7 @@ import pytest import numpy as np -import probeinterface as pi +import probeinterface from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor @@ -58,7 +58,7 @@ def test_ChannelSliceRecording(): assert np.all(traces[:, 1] == 0) # with probe and after save() - probe = pi.generate_linear_probe(num_elec=num_chan) + probe = probeinterface.generate_linear_probe(num_elec=num_chan) probe.set_device_channel_indices(np.arange(num_chan)) rec_p = rec.set_probe(probe) rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4]) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 501fd8cc79..6d5a753cad 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -297,6 +297,48 @@ def test_extract_waveforms(): ) assert we4.sparsity is not None + # test with sparsity estimation + folder5 = cache_folder / "test_extract_waveforms_compute_sparsity_tmp_folder" + sparsity_temp_folder = cache_folder / "tmp_sparsity" + if folder5.is_dir(): + shutil.rmtree(folder5) + + we5 = extract_waveforms( + recording, + sorting, + folder5, + max_spikes_per_unit=100, + return_scaled=True, + sparse=True, + sparsity_temp_folder=sparsity_temp_folder, + method="radius", + radius_um=50.0, + n_jobs=2, + chunk_duration="500ms", + ) + assert we5.sparsity is not None + # tmp folder is cleaned up + assert not sparsity_temp_folder.is_dir() + + # should raise an error if sparsity_temp_folder is not empty + with pytest.raises(AssertionError): + if folder5.is_dir(): + shutil.rmtree(folder5) + sparsity_temp_folder.mkdir() + we5 = extract_waveforms( + recording, + sorting, + folder5, + max_spikes_per_unit=100, + return_scaled=True, + sparse=True, + sparsity_temp_folder=sparsity_temp_folder, + method="radius", + radius_um=50.0, + n_jobs=2, + chunk_duration="500ms", + ) + def test_recordingless(): durations = [30, 40] diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index d4ae140b90..a2b58daa24 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -41,9 +41,8 @@ class WaveformExtractor: rec_attributes: None or dict When recording is None then a minimal dict with some attributes is needed. - allow_unfiltered: bool + allow_unfiltered: bool, default: False If true, will accept unfiltered recording. - False by default. Returns ------- we: WaveformExtractor @@ -61,7 +60,7 @@ class WaveformExtractor: >>> # Retrieve >>> waveforms = we.get_waveforms(unit_id) - >>> template = we.get_template(unit_id, mode='median') + >>> template = we.get_template(unit_id, mode="median") >>> # Load from folder (in another session) >>> we = WaveformExtractor.load(folder) @@ -167,7 +166,7 @@ def load_from_folder( pass elif (folder / "recording.pickle").exists(): try: - recording = load_extractor(folder / "recording.pickle") + recording = load_extractor(folder / "recording.pickle", base_folder=folder) except: pass if recording is None: @@ -178,7 +177,7 @@ def load_from_folder( if (folder / "sorting.json").exists(): sorting = load_extractor(folder / "sorting.json", base_folder=folder) elif (folder / "sorting.pickle").exists(): - sorting = load_extractor(folder / "sorting.pickle") + sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) else: raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") @@ -288,15 +287,12 @@ def create( if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) elif recording.check_serializablility("pickle"): - # In this case we loose the relative_to!! - recording.dump(folder / "recording.pickle") + recording.dump(folder / "recording.pickle", relative_to=relative_to) if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) elif sorting.check_serializablility("pickle"): - # In this case we loose the relative_to!! - # TODO later the dump to pickle should dump the dictionary and so relative could be put back - sorting.dump(folder / "sorting.pickle") + sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not serializable to file, which might result in downstream errors for " @@ -650,9 +646,8 @@ def set_recording( rec_attributes: None or dict When recording is None then a minimal dict with some attributes is needed. - allow_unfiltered: bool + allow_unfiltered: bool, default: False If true, will accept unfiltered recording. - False by default. """ if recording is None: # Recordless mode. @@ -875,15 +870,15 @@ def save( ---------- folder : str or Path The output waveform folder - format : str, optional - "binary", "zarr", by default "binary" + format : "binary" | "zarr", default: "binary" + The backend to use for saving the waveforms overwrite : bool - If True and folder exists, it is deleted, by default False - use_relative_path : bool, optional + If True and folder exists, it is deleted, default: False + use_relative_path : bool, default: False If True, the recording and sorting paths are relative to the waveforms folder. This allows portability of the waveform folder provided that the relative paths are the same, - but forces all the data files to be in the same drive, by default False - sparsity : ChannelSparsity, optional + but forces all the data files to be in the same drive + sparsity : ChannelSparsity, default: None If given and WaveformExtractor is not sparse, it makes the returned WaveformExtractor sparse """ folder = Path(folder) @@ -920,12 +915,12 @@ def save( if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) elif self.recording.check_serializablility("pickle"): - self.recording.dump(folder / "recording.pickle") + self.recording.dump(folder / "recording.pickle", relative_to=relative_to) if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) elif self.sorting.check_serializablility("pickle"): - self.sorting.dump(folder / "sorting.pickle") + self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not serializable to file, which might result in downstream errors for " @@ -1050,17 +1045,17 @@ def get_waveforms( ---------- unit_id: int or str Unit id to retrieve waveforms for - with_index: bool - If True, spike indices of extracted waveforms are returned (default False) - cache: bool - If True, waveforms are cached to the self._waveforms dictionary (default False) - lazy: bool + with_index: bool, default: False + If True, spike indices of extracted waveforms are returned + cache: bool, default: False + If True, waveforms are cached to the self._waveforms dictionary + lazy: bool, default: True If True, waveforms are loaded as memmap objects (when format="binary") or Zarr datasets (when format="zarr"). - If False, waveforms are loaded as np.array objects (default True) - sparsity: ChannelSparsity, optional + If False, waveforms are loaded as np.array objects + sparsity: ChannelSparsity, default: None Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) - force_dense: bool (False) + force_dense: bool, default: False Return dense waveforms even if the waveform extractor is sparse Returns @@ -1068,7 +1063,7 @@ def get_waveforms( wfs: np.array The returned waveform (num_spikes, num_samples, num_channels) indices: np.array - If 'with_index' is True, the spike indices corresponding to the waveforms extracted + If "with_index" is True, the spike indices corresponding to the waveforms extracted """ assert unit_id in self.sorting.unit_ids, "'unit_id' is invalid" assert self.has_waveforms(), "Waveforms have been deleted!" @@ -1164,7 +1159,7 @@ def get_waveforms_segment(self, segment_index: int, unit_id, sparsity): The segment index to retrieve waveforms from unit_id: int or str Unit id to retrieve waveforms for - sparsity: ChannelSparsity, optional + sparsity: ChannelSparsity, default: None Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) Returns @@ -1229,8 +1224,8 @@ def get_all_templates(self, unit_ids: Optional[Iterable] = None, mode="average") ---------- unit_ids: list or None Unit ids to retrieve waveforms for - mode: str - 'average' (default) or 'median' , 'std' + mode: "average" | "median" | "std", default: "average" + The mode to compute the templates Returns ------- @@ -1256,9 +1251,9 @@ def get_template(self, unit_id, mode="average", sparsity=None, force_dense: bool ---------- unit_id: int or str Unit id to retrieve waveforms for - mode: str - 'average' (default), 'median' , 'std'(standard deviation) - sparsity: ChannelSparsity, optional + mode: "average" | "median" | "std", default: "average" + The mode to compute the template + sparsity: ChannelSparsity, default: None Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) force_dense: bool (False) Return a dense template even if the waveform extractor is sparse @@ -1314,9 +1309,9 @@ def get_template_segment(self, unit_id, segment_index, mode="average", sparsity= Unit id to retrieve waveforms for segment_index: int The segment index to retrieve template from - mode: str - 'average' (default), 'median', 'std'(standard deviation) - sparsity: ChannelSparsity, optional + mode: "average" | "median" | "std", default: "average" + The mode to compute the template + sparsity: ChannelSparsity, default: None Sparsity to apply to the waveforms (if WaveformExtractor is not sparse). Returns @@ -1481,7 +1476,9 @@ def extract_waveforms( dtype=None, sparse=True, sparsity=None, + sparsity_temp_folder=None, num_spikes_for_sparsity=100, + unit_batch_size=200, allow_unfiltered=False, use_relative_path=False, seed=None, @@ -1502,46 +1499,52 @@ def extract_waveforms( The recording object sorting: Sorting The sorting object - folder: str or Path or None + folder: str or Path or None, default: None The folder where waveforms are cached - mode: str - "folder" (default) or "memory". The "folder" argument must be specified in case of mode "folder". + mode: "folder" | "memory, default: "folder" + The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder. + The "folder" argument must be specified in case of mode "folder". If "memory" is used, the waveforms are stored in RAM. Use this option carefully! - precompute_template: None or list - Precompute average/std/median for template. If None not precompute. - ms_before: float + precompute_template: None or list, default: ["average"] + Precompute average/std/median for template. If None, no templates are precomputed + ms_before: float, default: 1.0 Time in ms to cut before spike peak - ms_after: float + ms_after: float, default: 2.0 Time in ms to cut after spike peak - max_spikes_per_unit: int or None - Number of spikes per unit to extract waveforms from (default 500). + max_spikes_per_unit: int or None, default: 500 + Number of spikes per unit to extract waveforms from Use None to extract waveforms for all spikes - overwrite: bool - If True and 'folder' exists, the folder is removed and waveforms are recomputed. + overwrite: bool, default: False + If True and "folder" exists, the folder is removed and waveforms are recomputed Otherwise an error is raised. - return_scaled: bool - If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. - dtype: dtype or None - Dtype of the output waveforms. If None, the recording dtype is maintained. + return_scaled: bool, default: True + If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV + dtype: dtype or None, default: None + Dtype of the output waveforms. If None, the recording dtype is maintained sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. When True, you must some provide kwargs handle `precompute_sparsity()` to control the kind of sparsity you want to apply (by radius, by best channels, ...). - sparsity: ChannelSparsity or None + sparsity: ChannelSparsity or None, default: None The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. - num_spikes_for_sparsity: int (default 100) + sparsity_temp_folder: str or Path or None, default: None + If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved. + If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size` + parameter. With a large number of units (e.g., > 400), it is advisable to use a temporary folder. + num_spikes_for_sparsity: int, default: 100 The number of spikes to use to estimate sparsity (if sparse=True). + unit_batch_size: int, default: 200 + The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder + is None). allow_unfiltered: bool If true, will accept an allow_unfiltered recording. - False by default. - use_relative_path: bool + use_relative_path: bool, default: False If True, the recording and sorting paths are relative to the waveforms folder. This allows portability of the waveform folder provided that the relative paths are the same, but forces all the data files to be in the same drive. - Default is False. - seed: int or None + seed: int or None, default: None Random seed for spike selection sparsity kwargs: @@ -1612,6 +1615,8 @@ def extract_waveforms( ms_before=ms_before, ms_after=ms_after, num_spikes_for_sparsity=num_spikes_for_sparsity, + unit_batch_size=unit_batch_size, + temp_folder=sparsity_temp_folder, allow_unfiltered=allow_unfiltered, **estimate_kwargs, **job_kwargs, @@ -1654,11 +1659,11 @@ def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSo ---------- folder : str or Path The folder / zarr folder where the waveform extractor is stored - with_recording : bool, optional - If True, the recording is loaded, by default True. + with_recording : bool, default: True + If True, the recording is loaded. If False, the WaveformExtractor object in recordingless mode. - sorting : BaseSorting, optional - If passed, the sorting object associated to the waveform extractor, by default None + sorting : BaseSorting, default: None + If passed, the sorting object associated to the waveform extractor Returns ------- @@ -1675,6 +1680,7 @@ def precompute_sparsity( unit_batch_size=200, ms_before=2.0, ms_after=3.0, + temp_folder=None, allow_unfiltered=False, **kwargs, ): @@ -1689,25 +1695,25 @@ def precompute_sparsity( The recording object sorting: Sorting The sorting object - num_spikes_for_sparsity: int - How many spikes per unit. - unit_batch_size: int or None + num_spikes_for_sparsity: int, default: 100 + How many spikes per unit + unit_batch_size: int or None, default: 200 How many units are extracted at once to estimate sparsity. - If None then they are extracted all at one (consum many memory) - ms_before: float + If None then they are extracted all at one (but uses a lot of memory) + ms_before: float, default: 2.0 Time in ms to cut before spike peak - ms_after: float + ms_after: float, default: 3.0 Time in ms to cut after spike peak - allow_unfiltered: bool + temp_folder: str or Path or None, default: None + If provided, dense waveforms are saved to this temporary folder + allow_unfiltered: bool, default: False If true, will accept an allow_unfiltered recording. - False by default. - kwargs for sparsity strategy: {} - Job kwargs: + job kwargs: {} Returns @@ -1724,18 +1730,38 @@ def precompute_sparsity( if unit_batch_size is None: unit_batch_size = len(unit_ids) - mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") - - nloop = int(np.ceil((unit_ids.size / unit_batch_size))) - for i in range(nloop): - sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size) - local_ids = unit_ids[sl] - local_sorting = sorting.select_units(local_ids) - local_we = extract_waveforms( + if temp_folder is None: + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") + nloop = int(np.ceil((unit_ids.size / unit_batch_size))) + for i in range(nloop): + sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size) + local_ids = unit_ids[sl] + local_sorting = sorting.select_units(local_ids) + local_we = extract_waveforms( + recording, + local_sorting, + folder=None, + mode="memory", + precompute_template=("average",), + ms_before=ms_before, + ms_after=ms_after, + max_spikes_per_unit=num_spikes_for_sparsity, + return_scaled=False, + allow_unfiltered=allow_unfiltered, + sparse=False, + **job_kwargs, + ) + local_sparsity = compute_sparsity(local_we, **sparse_kwargs) + mask[sl, :] = local_sparsity.mask + else: + temp_folder = Path(temp_folder) + assert ( + not temp_folder.is_dir() + ), "Temporary folder for pre-computing sparsity already exists. Provide a non-existing folder" + dense_we = extract_waveforms( recording, - local_sorting, - folder=None, - mode="memory", + sorting, + folder=temp_folder, precompute_template=("average",), ms_before=ms_before, ms_after=ms_after, @@ -1745,8 +1771,9 @@ def precompute_sparsity( sparse=False, **job_kwargs, ) - local_sparsity = compute_sparsity(local_we, **sparse_kwargs) - mask[sl, :] = local_sparsity.mask + sparsity = compute_sparsity(dense_we, **sparse_kwargs) + mask = sparsity.mask + shutil.rmtree(temp_folder) sparsity = ChannelSparsity(mask, unit_ids, channel_ids) return sparsity diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a2f1296e31..1588bc926c 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -54,18 +54,18 @@ def extract_waveforms_to_buffers( N samples before spike nafter: int N samples after spike - mode: str - Mode to use ('memmap' | 'shared_memory') - return_scaled: bool - Scale traces before exporting to buffer or not. - folder: str or path + mode: "memmap" | "shared_memory", default: "memmap" + The mode to use for the buffer + return_scaled: bool, default: False + Scale traces before exporting to buffer or not + folder: str or path or None, default: None In case of memmap mode, folder to save npy files - dtype: numpy.dtype + dtype: numpy.dtype, default: None dtype for waveforms buffer - sparsity_mask: None or array of bool + sparsity_mask: None or array of bool, default: None If not None shape must be must be (len(unit_ids), len(channel_ids)) - copy: bool - If True (default), the output shared memory object is copied to a numpy standard array. + copy: bool, default: False + If True, the output shared memory object is copied to a numpy standard array. If copy=False then arrays_info is also return. Please keep in mind that arrays_info need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually @@ -147,8 +147,8 @@ def allocate_waveforms_buffers( N samples before spike nafter: int N samples after spike - mode: str - Mode to use ('memmap' | 'shared_memory') + mode: "memmap" | "shared_memory", default: "memmap" + Mode to use folder: str or path In case of memmap mode, folder to save npy files dtype: numpy.dtype @@ -242,8 +242,8 @@ def distribute_waveforms_to_buffers( N samples after spike return_scaled: bool Scale traces before exporting to buffer or not. - mode: str - Mode to use ('memmap' | 'shared_memory') + mode: "memmap" | "shared_memory", default: "memmap" + Mode to use sparsity_mask: None or array of bool If not None shape must be must be (len(unit_ids), len(channel_ids) @@ -419,7 +419,7 @@ def extract_waveforms_to_single_buffer( Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be referenced as long as arrays is used. This variable must also unlink() when the array is de-referenced. - To avoid this complicated behavior, by default (copy=True) the shared memmory buffer is copied into a standard + To avoid this complicated behavior, default: (copy=True) the shared memmory buffer is copied into a standard numpy array. @@ -436,18 +436,18 @@ def extract_waveforms_to_single_buffer( N samples before spike nafter: int N samples after spike - mode: str - Mode to use ('memmap' | 'shared_memory') - return_scaled: bool - Scale traces before exporting to buffer or not. - file_path: str or path - In case of memmap mode, file to save npy file. - dtype: numpy.dtype + mode: "memmap" | "shared_memory", default: "memmap" + The mode to use for the buffer + return_scaled: bool, default: False + Scale traces before exporting to buffer or not + file_path: str or path or None, default: None + In case of memmap mode, file to save npy file + dtype: numpy.dtype, default: None dtype for waveforms buffer - sparsity_mask: None or array of bool + sparsity_mask: None or array of bool, default: None If not None shape must be must be (len(unit_ids), len(channel_ids)) - copy: bool - If True (default), the output shared memory object is copied to a numpy standard array and no reference + copy: bool, default: False + If True, the output shared memory object is copied to a numpy standard array and no reference to the internal shared memory object is kept. If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation diff --git a/src/spikeinterface/core/zarrrecordingextractor.py b/src/spikeinterface/core/zarrrecordingextractor.py index 4dc94a24dd..6c15044ad9 100644 --- a/src/spikeinterface/core/zarrrecordingextractor.py +++ b/src/spikeinterface/core/zarrrecordingextractor.py @@ -169,7 +169,7 @@ def get_default_zarr_compressor(clevel=5): Parameters ---------- - clevel : int, optional + clevel : int, default: 5 Compression level (higher -> more compressed). Minimum 1, maximum 9. By default 5 diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 5e7047a5c1..6db8d856cb 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -58,49 +58,48 @@ def get_potential_auto_merge( ---------- waveform_extractor: WaveformExtractor The waveform extractor - minimum_spikes: int + minimum_spikes: int, default: 1000 Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram, by default 1000 - maximum_distance_um: float - Minimum distance between units for considering a merge, by default 150 - peak_sign: "neg"/"pos"/"both" - Peak sign used to estimate the maximum channel of a template, by default "neg" - bin_ms: float - Bin size in ms used for computing the correlogram, by default 0.25 - window_ms: float - Window size in ms used for computing the correlogram, by default 100 - corr_diff_thresh: float + Enough spikes are needed to estimate the correlogram + maximum_distance_um: float, default: 150 + Minimum distance between units for considering a merge + peak_sign: "neg" | "pos" | "both", default: "neg" + Peak sign used to estimate the maximum channel of a template + bin_ms: float, default: 0.25 + Bin size in ms used for computing the correlogram + window_ms: float, default: 100 + Window size in ms used for computing the correlogram + corr_diff_thresh: float, default: 0.16 The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1, by default 0.16 - template_diff_thresh: float + It needs to be between 0 and 1 + template_diff_thresh: float, default: 0.25 The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1, by default 0.25 - censored_period_ms: float - Used to compute the refractory period violations aka "contamination", by default 0 - refractory_period_ms: float - Used to compute the refractory period violations aka "contamination", by default 1 - sigma_smooth_ms: float - Parameters to smooth the correlogram estimation, by default 0.6 - contamination_threshold: float - Threshold for not taking in account a unit when it is too contaminated, by default 0.2 - adaptative_window_threshold:: float - Parameter to detect the window size in correlogram estimation, by default 0.5 - censor_correlograms_ms: float - The period to censor on the auto and cross-correlograms, by default 0.15 ms - num_channels: int - Number of channel to use for template similarity computation, by default 5 - num_shift: int - Number of shifts in samles to be explored for template similarity computation, by default 5 - firing_contamination_balance: float - Parameter to control the balance between firing rate and contamination in computing unit "quality score", - by default 1.5 - extra_outputs: bool - If True, an additional dictionary (`outs`) with processed data is returned, by default False - steps: None or list of str + It needs to be between 0 and 1 + censored_period_ms: float, default: 0.3 + Used to compute the refractory period violations aka "contamination" + refractory_period_ms: float, default: 1 + Used to compute the refractory period violations aka "contamination" + sigma_smooth_ms: float, default: 0.6 + Parameters to smooth the correlogram estimation + contamination_threshold: float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated + adaptative_window_threshold:: float, default: 0.5 + Parameter to detect the window size in correlogram estimation + censor_correlograms_ms: float, default: 0.15 + The period to censor on the auto and cross-correlograms + num_channels: int, default: 5 + Number of channel to use for template similarity computation + num_shift: int, default: 5 + Number of shifts in samles to be explored for template similarity computation + firing_contamination_balance: float, default: 1.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score" + extra_outputs: bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned + steps: None or list of str, default: None which steps to run (gives flexibility to running just some steps) If None all steps are done. - Pontential steps: 'min_spikes', 'remove_contaminated', 'unit_positions', 'correlogram', 'template_similarity', - 'check_increase_score'. Please check steps explanations above! + Pontential steps: "min_spikes", "remove_contaminated", "unit_positions", "correlogram", "template_similarity", + "check_increase_score". Please check steps explanations above! Returns ------- @@ -312,7 +311,7 @@ def smooth_correlogram(correlograms, bins, sigma_smooth_ms=0.6): import scipy.signal # OLD implementation : smooth correlogram by low pass filter - # b, a = scipy.signal.butter(N=2, Wn = correlogram_low_pass / (1e3 / bin_ms /2), btype='low') + # b, a = scipy.signal.butter(N=2, Wn = correlogram_low_pass / (1e3 / bin_ms /2), btype="low") # correlograms_smoothed = scipy.signal.filtfilt(b, a, correlograms, axis=2) # new implementation smooth by convolution with a Gaussian kernel @@ -378,10 +377,10 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair The sorting object templates : np.array The templates array (num_units, num_samples, num_channels) - num_channels: int, optional - Number of channel to use for template similarity computation, by default 5 - num_shift: int, optional - Number of shifts in samles to be explored for template similarity computation, by default 5 + num_channels: int, default: 5 + Number of channel to use for template similarity computation + num_shift: int, default: 5 + Number of shifts in samles to be explored for template similarity computation pair_mask: None or boolean array A bool matrix of size (num_units, num_units) to select which pair to compute. diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index ddf7d4dc9d..4badaebbbb 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -109,7 +109,7 @@ def find_duplicated_spikes( The spike train on which to look for duplicated spikes. censored_period: int The censored period for duplicates (in sample time). - method: "keep_first" |"keep_last" | "keep_first_iterative' | 'keep_last_iterative" |random" + method: "keep_first" |"keep_last" | "keep_first_iterative" | "keep_last_iterative" |random", default: "random" Method used to remove the duplicated spikes. seed: int | None The seed to use if method="random". diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index bdb33e9eb1..423baa0818 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -14,10 +14,10 @@ class CurationSorting: ---------- parent_sorting: Recording The recording object - properties_policy: str - Policy used to propagate properties after split and merge operation. If 'keep' the properties will be - passed to the new units (if the original units have the same value). If 'remove' the new units will have - an empty value for all the properties. Default: 'keep' + properties_policy: "keep" | "remove", default: "keep" + Policy used to propagate properties after split and merge operation. If "keep" the properties will be + passed to the new units (if the original units have the same value). If "remove" the new units will have + an empty value for all the properties make_graph: bool True to keep a Networkx graph instance with the curation history Returns diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 5295cc76d8..ae033d5531 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -17,11 +17,10 @@ class MergeUnitsSorting(BaseSorting): but it can also have more (merge multiple units at once). new_unit_ids: None or list A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge` - properties_policy: str ('keep', 'remove') - Policy used to propagate properties. If 'keep' the properties will be passed to the new units - (if the units_to_merge have the same value). If 'remove' the new units will have an empty + properties_policy: "keep" | "remove", default: "keep" + Policy used to propagate properties. If "keep" the properties will be passed to the new units + (if the units_to_merge have the same value). If "remove" the new units will have an empty value for all the properties of the new unit. - Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. None won't check for duplications diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 04af69b37a..e29e88377e 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -17,7 +17,7 @@ class RemoveDuplicatedSpikesSorting(BaseSorting): The parent sorting. censored_period_ms: float The censored period to consider 2 spikes to be duplicated (in ms). - method: str in ("keep_first", "keep_last", "keep_first_iterative', 'keep_last_iterative", random") + method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random", default: "keep_first" Method used to remove the duplicated spikes. If method = "random", will randomly choose to remove the first or last spike. If method = "keep_first", for each ISI violation, will remove the second spike. diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index e13f83550a..88868c8730 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -26,9 +26,9 @@ def remove_redundant_units( When a redundant pair is found, there are several strategies to choose which unit is the best: - * 'minimum_shift' - * 'highest_amplitude' - * 'max_spikes' + * "minimum_shift" + * "highest_amplitude" + * "max_spikes" Parameters @@ -37,26 +37,26 @@ def remove_redundant_units( If WaveformExtractor, the spike trains can be optionally realigned using the peak shift in the template to improve the matching procedure. If BaseSorting, the spike trains are not aligned. - align : bool, optional - If True, spike trains are aligned (if a WaveformExtractor is used), by default False - delta_time : float, optional - The time in ms to consider matching spikes, by default 0.4 - agreement_threshold : float, optional - Threshold on the agreement scores to flag possible redundant/duplicate units, by default 0.2 - duplicate_threshold : float, optional + align : bool, default: False + If True, spike trains are aligned (if a WaveformExtractor is used) + delta_time : float, default: 0.4 + The time in ms to consider matching spikes + agreement_threshold : float, default: 0.2 + Threshold on the agreement scores to flag possible redundant/duplicate units + duplicate_threshold : float, default: 0.8 Final threshold on the portion of coincident events over the number of spikes above which the - unit is removed, by default 0.8 - remove_strategy: 'minimum_shift' | 'highest_amplitude' | 'max_spikes', default: 'minimum_shift' + unit is removed + remove_strategy: "minimum_shift" | "highest_amplitude" | "max_spikes", default: "minimum_shift" Which strategy to remove one of the two duplicated units: - * 'minimum_shift': keep the unit with best peak alignment (minimum shift) - If shifts are equal then the 'highest_amplitude' is used - * 'highest_amplitude': keep the unit with the best amplitude on unshifted max. - * 'max_spikes': keep the unit with more spikes + * "minimum_shift": keep the unit with best peak alignment (minimum shift) + If shifts are equal then the "highest_amplitude" is used + * "highest_amplitude": keep the unit with the best amplitude on unshifted max. + * "max_spikes": keep the unit with more spikes - peak_sign: 'neg' |'pos' | 'both', default: 'neg' - Used when remove_strategy='highest_amplitude' - extra_outputs: bool + peak_sign: "neg" | "pos" | "both", default: "neg" + Used when remove_strategy="highest_amplitude" + extra_outputs: bool, default: False If True, will return the redundant pairs. Returns @@ -147,13 +147,13 @@ def find_redundant_units(sorting, delta_time: float = 0.4, agreement_threshold=0 ---------- sorting : BaseSorting The input sorting object - delta_time : float, optional - The time in ms to consider matching spikes, by default 0.4 - agreement_threshold : float, optional - Threshold on the agreement scores to flag possible redundant/duplicate units, by default 0.2 - duplicate_threshold : float, optional + delta_time : float, default: 0.4 + The time in ms to consider matching spikes + agreement_threshold : float, default: 0.2 + Threshold on the agreement scores to flag possible redundant/duplicate units + duplicate_threshold : float, default: 0.8 Final threshold on the portion of coincident events over the number of spikes above which the - unit is flagged as duplicate/redundant, by default 0.8 + unit is flagged as duplicate/redundant Returns ------- diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 626ea79eb9..e6427b32a2 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -19,16 +19,16 @@ def apply_sortingview_curation( The sorting object to be curated uri_or_json : str or Path The URI curation link from SortingView or the path to the curation json file - exclude_labels : list, optional + exclude_labels : list, default: None Optional list of labels to exclude (e.g. ["reject", "noise"]). - Mutually exclusive with include_labels, by default None - include_labels : list, optional + Mutually exclusive with include_labels + include_labels : list, default: None Optional list of labels to include (e.g. ["accept"]). Mutually exclusive with exclude_labels, by default None - skip_merge : bool, optional - If True, merges are not applied (only labels), by default False - verbose : bool, optional - If True, output is verbose, by default False + skip_merge : bool, default: False + If True, merges are not applied (only labels) + verbose : bool, default: False + If True, output is verbose Returns ------- diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 23863a85e5..75b50c373f 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -21,9 +21,9 @@ class SplitUnitSorting(BaseSorting): be the same length as the spike train (for each segment) new_unit_ids: int Unit ids of the new units to be created. - properties_policy: 'keep' | 'remove', default: 'keep' - Policy used to propagate properties. If 'keep' the properties will be passed to the new units - (if the units_to_merge have the same value). If 'remove' the new units will have an empty + properties_policy: "keep" | "remove", default: "keep" + Policy used to propagate properties. If "keep" the properties will be passed to the new units + (if the units_to_merge have the same value). If "remove" the new units will have an empty value for all the properties of the new unit. Returns ------- diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 4910c4348f..57a5ab0166 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -30,16 +30,16 @@ def export_report( If WaveformExtractor is provide then the compute is faster otherwise output_folder: str The output folder where the report files are saved - remove_if_exists: bool + remove_if_exists: bool, default: False If True and the output folder exists, it is removed - format: str - 'png' (default) or 'pdf' or any format handled by matplotlib - peak_sign: 'neg' or 'pos' + format: str, default: "png" + The output figure format (any format handled by matplotlib) + peak_sign: "neg" or "pos", default: "neg" used to compute amplitudes and metrics - show_figures: bool - If True, figures are shown. If False (default), figures are closed after saving. - force_computation: bool default False - Force or not some heavy computaion before exporting. + show_figures: bool, default: False + If True, figures are shown. If False, figures are closed after saving + force_computation: bool, default: False + Force or not some heavy computaion before exporting {} """ import pandas as pd diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 0529c99d12..59771331bc 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -47,27 +47,27 @@ def export_to_phy( If WaveformExtractor is provide then the compute is faster otherwise output_folder: str | Path The output folder where the phy template-gui files are saved - compute_pc_features: bool - If True (default), pc features are computed - compute_amplitudes: bool - If True (default), waveforms amplitudes are computed - sparsity: ChannelSparsity or None - The sparsity object. - copy_binary: bool - If True, the recording is copied and saved in the phy 'output_folder' - remove_if_exists: bool - If True and 'output_folder' exists, it is removed and overwritten - peak_sign: 'neg', 'pos', 'both' + compute_pc_features: bool, default: True + If True, pc features are computed + compute_amplitudes: bool, default: True + If True, waveforms amplitudes are computed + sparsity: ChannelSparsity or None, default: None + The sparsity object + copy_binary: bool, default: True + If True, the recording is copied and saved in the phy "output_folder" + remove_if_exists: bool, default: False + If True and "output_folder" exists, it is removed and overwritten + peak_sign: "neg" | "pos" | "both", default: "neg" Used by compute_spike_amplitudes - template_mode: str - Parameter 'mode' to be given to WaveformExtractor.get_template() - dtype: dtype or None + template_mode: str, default: "median" + Parameter "mode" to be given to WaveformExtractor.get_template() + dtype: dtype or None, default: None Dtype to save binary data - verbose: bool + verbose: bool, default: True If True, output is verbose use_relative_path : bool, default: False - If True and `copy_binary=True` saves the binary file `dat_path` in the `params.py` relative to `output_folder` (ie `dat_path=r'recording.dat'`). If `copy_binary=False`, then uses a path relative to the `output_folder` - If False, uses an absolute path in the `params.py` (ie `dat_path=r'path/to/the/recording.dat'`) + If True and `copy_binary=True` saves the binary file `dat_path` in the `params.py` relative to `output_folder` (ie `dat_path=r"recording.dat"`). If `copy_binary=False`, then uses a path relative to the `output_folder` + If False, uses an absolute path in the `params.py` (ie `dat_path=r"path/to/the/recording.dat"`) {} """ diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 8b70722652..b1888f4a27 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -3,7 +3,7 @@ import numpy as np import neo -from probeinterface import read_BIDS_probe +import probeinterface from .nwbextractors import read_nwb from .neoextractors import read_nix @@ -60,7 +60,7 @@ def read_bids(folder_path): def _read_probe_group(folder, bids_name, recording_channel_ids): - probegroup = read_BIDS_probe(folder) + probegroup = probeinterface.read_BIDS_probe(folder) # make maps between : channel_id and contact_id using _channels.tsv import pandas as pd diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index bd56208ebe..37ed931d1a 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,6 @@ from pathlib import Path -import probeinterface as pi +import probeinterface from spikeinterface.core import BaseRecording, BaseRecordingSegment from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts @@ -89,7 +89,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"): self.set_channel_offsets(offsets) if not load_sync_channel: - probe = pi.read_spikeglx(meta_file) + probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 0980e89f1c..3436313b4d 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -24,9 +24,9 @@ class CellExplorerSortingExtractor(BaseSorting): ---------- file_path: str | Path Path to `.mat` file containing spikes. Usually named `session_id.spikes.cellinfo.mat` - sampling_frequency: float | None, optional + sampling_frequency: float | None, default: None The sampling frequency of the data. If None, it will be extracted from the files. - session_info_file_path: str | Path | None, optional + session_info_file_path: str | Path | None, default: None Path to the `sessionInfo.mat` file. If None, it will be inferred from the file_path. """ diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index fa2bdde450..737bdfe7a4 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -22,9 +22,9 @@ class CombinatoSortingExtractor(BaseSorting): Path to the Combinato folder. sampling_frequency : int, default: 30000 The sampling frequency. - user : str - The username that ran the sorting. Defaults to 'simple'. - det_sign : {'both', 'pos', 'neg'} + user : str, default: "simple" + The username that ran the sorting + det_sign : "both", "pos", "neg", default: "both" Which sign was used for detection. keep_good_only : bool, default: True Whether to only keep good units. diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 178596d052..074a3fbd40 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -191,13 +191,13 @@ def write_sorting(sorting, save_path, locations=None, noise_std_by_channel=None, if noise_std_by_channel is None: noise_std_by_channel = np.ones((1, n_channels)) - dict_to_save = {'Units': units, - 'MultiElectrode': multi_electrode, - 'noiseStd': noise_std_by_channel, + dict_to_save = {"Units": units, + "MultiElectrode": multi_electrode, + "noiseStd": noise_std_by_channel, "samplingRate": sorting._sampling_frequency} # Save Units and MultiElectrode to .mat file: - MATSortingExtractor.write_dict_to_mat(save_path, dict_to_save, version='7.3') + MATSortingExtractor.write_dict_to_mat(save_path, dict_to_save, version="7.3") """ diff --git a/src/spikeinterface/extractors/iblstreamingrecording.py b/src/spikeinterface/extractors/iblstreamingrecording.py index fcd03f8bcf..35dccbef1e 100644 --- a/src/spikeinterface/extractors/iblstreamingrecording.py +++ b/src/spikeinterface/extractors/iblstreamingrecording.py @@ -4,7 +4,7 @@ from pathlib import Path import numpy as np -import probeinterface as pi +import probeinterface from spikeinterface.core import BaseRecording, BaseRecordingSegment from spikeinterface.core.core_tools import define_function_from_class @@ -18,24 +18,24 @@ class IblStreamingRecordingExtractor(BaseRecording): ---------- session : str The session ID to extract recordings for. - In ONE, this is sometimes referred to as the 'eid'. + In ONE, this is sometimes referred to as the "eid". When doing a session lookup such as >>> from one.api import ONE >>> one = ONE(base_url="https://openalyx.internationalbrainlab.org", password="international", silent=True) - >>> sessions = one.alyx.rest('sessions', 'list', tag='2022_Q2_IBL_et_al_RepeatedSite') + >>> sessions = one.alyx.rest("sessions", "list", tag="2022_Q2_IBL_et_al_RepeatedSite") - each returned value in `sessions` refers to it as the 'id'. + each returned value in `sessions` refers to it as the "id". stream_name : str The name of the stream to load for the session. These can be retrieved from calling `StreamingIblExtractor.get_stream_names(session="")`. load_sync_channels : bool, default: false Load or not the last channel (sync). If not then the probe is loaded. - cache_folder : str, optional + cache_folder : str or None, default: None The location to temporarily store chunks of data during streaming. The default uses the folder designated by ONE.alyx._par.CACHE_DIR / "cache", which is typically the designated - 'Downloads' folder on your operating system. As long as `remove_cached` is set to True, the only files that will + "Downloads" folder on your operating system. As long as `remove_cached` is set to True, the only files that will persist in this folder are the metadata header files and the chunk of data being actively streamed and used in RAM. remove_cached : bool, default: True Whether or not to remove streamed data from the cache immediately after it is read. @@ -61,14 +61,14 @@ def get_stream_names(cls, session: str, cache_folder: Optional[Union[Path, str]] ---------- session : str The session ID to extract recordings for. - In ONE, this is sometimes referred to as the 'eid'. + In ONE, this is sometimes referred to as the "eid". When doing a session lookup such as >>> from one.api import ONE >>> one = ONE(base_url="https://openalyx.internationalbrainlab.org", password="international", silent=True) - >>> sessions = one.alyx.rest('sessions', 'list', tag='2022_Q2_IBL_et_al_RepeatedSite') + >>> sessions = one.alyx.rest("sessions", "list", tag="2022_Q2_IBL_et_al_RepeatedSite") - each returned value in `sessions` refers to it as the 'id'. + each returned value in `sessions` refers to it as the "id". Returns ------- @@ -165,7 +165,7 @@ def __init__( # set probe if not load_sync_channel: - probe = pi.read_spikeglx(meta_file) + probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/klustaextractors.py b/src/spikeinterface/extractors/klustaextractors.py index f6a86ae9ae..83718cffb2 100644 --- a/src/spikeinterface/extractors/klustaextractors.py +++ b/src/spikeinterface/extractors/klustaextractors.py @@ -31,7 +31,7 @@ class KlustaSortingExtractor(BaseSorting): ---------- file_or_folder_path : str or Path Path to the ALF folder. - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). Returns diff --git a/src/spikeinterface/extractors/mclustextractors.py b/src/spikeinterface/extractors/mclustextractors.py index 9ca802c58d..dfe5bcda26 100644 --- a/src/spikeinterface/extractors/mclustextractors.py +++ b/src/spikeinterface/extractors/mclustextractors.py @@ -15,9 +15,9 @@ class MClustSortingExtractor(BaseSorting): Path to folder with t files. sampling_frequency : sampling frequency sampling frequency in Hz. - sampling_frequency_raw: float or None + sampling_frequency_raw: float or None, default: None Required to read files with raw formats. In that case, the samples are saved in the same - unit as the input data. Default None + unit as the input data Examples: - If raw time is in tens of ms sampling_frequency_raw=10000 - If raw time is in samples sampling_frequency_raw=sampling_frequency diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 1eb0182318..229e3ef0d0 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -21,12 +21,12 @@ class MdaRecordingExtractor(BaseRecording): ---------- folder_path : str or Path Path to the MDA folder. - raw_fname: str - File name of raw file. Defaults to 'raw.mda'. - params_fname: str - File name of params file. Defaults to 'params.json'. - geom_fname: str - File name of geom file. Defaults to 'geom.csv'. + raw_fname: str, default: "raw.mda" + File name of raw file + params_fname: str, default: "params.json" + File name of params file + geom_fname: str, default: "geom.csv" + File name of geom file Returns ------- @@ -87,13 +87,13 @@ def write_recording( params: dictionary Dictionary with optional parameters to save metadata. Sampling frequency is appended to this dictionary. - raw_fname: str - File name of raw file. Defaults to 'raw.mda'. - params_fname: str - File name of params file. Defaults to 'params.json'. - geom_fname: str - File name of geom file. Defaults to 'geom.csv'. - dtype: dtype + raw_fname: str, default: "raw.mda" + File name of raw file + params_fname: str, default: "params.json" + File name of params file + geom_fname: str, default: "geom.csv" + File name of geom file + dtype: dtype or None, default: None Data type to be used. If None dtype is same as recording traces. **job_kwargs: Use by job_tools modules to set: diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index a58b5ab5ec..8d9eee0924 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -15,11 +15,11 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): ---------- folder_path: str or Path-like The folder path to the AlphaOmega recordings. - lsx_files: list of strings or None, optional + lsx_files: list of strings or None, default: None A list of listings files that refers to mpx files to load. - stream_id: {'RAW', 'LFP', 'SPK', 'ACC', 'AI', 'UD'}, optional + stream_id: {"RAW", "LFP", "SPK", "ACC", "AI", "UD"}, default: "RAW" If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 3e30cf77ae..b4f1e3f341 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,6 +1,6 @@ from pathlib import Path -import probeinterface as pi +import probeinterface from spikeinterface.core.core_tools import define_function_from_class @@ -17,15 +17,15 @@ class BiocamRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - mea_pitch: float, optional + mea_pitch: float, default: None The inter-electrode distance (pitch) between electrodes. - electrode_width: float, optional + electrode_width: float, default: None Width of the electrodes in um. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. - all_annotations: bool (default False) + all_annotations: bool, default: False Load exhaustively all annotations from neo. """ @@ -54,7 +54,7 @@ def __init__( probe_kwargs["mea_pitch"] = mea_pitch if electrode_width is not None: probe_kwargs["electrode_width"] = electrode_width - probe = pi.read_3brain(file_path, **probe_kwargs) + probe = probeinterface.read_3brain(file_path, **probe_kwargs) self.set_probe(probe, in_place=True) self.set_property("row", self.get_property("contact_vector")["row"]) self.set_property("col", self.get_property("contact_vector")["col"]) diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 8300e6bc5e..474bdd21a0 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -19,9 +19,9 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. @@ -74,14 +74,14 @@ class BlackrockSortingExtractor(NeoBaseSortingExtractor): Parameters ---------- file_path: str - The file path to load the recordings from. - sampling_frequency: float, None by default. + The file path to load the recordings from + sampling_frequency: float, default: None The sampling frequency for the sorting extractor. When the signal data is available (.ncs) those files will be used to extract the frequency automatically. Otherwise, the sampling frequency needs to be specified for - this extractor to be initialized. - stream_id: str, optional + this extractor to be initialized + stream_id: str, default: None Used to extract information about the sampling frequency and t_start from the analog signal if provided. - stream_name: str, optional + stream_name: str, default: None Used to extract information about the sampling frequency and t_start from the analog signal if provided. """ diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 2451ca8fe1..e7bc1bffb4 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -17,11 +17,11 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to the smr or smrx file. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. - block_index: int, optional + block_index: int, default: None If there are several blocks, specify the block index you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 5d8c56ee87..5aa51d9725 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -15,10 +15,10 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. For this neo reader streams are defined by their sampling frequency. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 2a61e7385f..3584844180 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -15,9 +15,9 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index ac85dbdf30..ca03aa7f85 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -1,7 +1,7 @@ import numpy as np from pathlib import Path -import probeinterface as pi +import probeinterface from spikeinterface import BaseEvent, BaseEventSegment from spikeinterface.core.core_tools import define_function_from_class @@ -20,15 +20,15 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to the maxwell h5 file. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. For MaxTwo when there are several wells at the same time you need to specify stream_id='well000' or 'well0001', etc. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. - rec_name: str, optional + rec_name: str, default: None When the file contains several recordings you need to specify the one you want to extract. (rec_name='rec0000'). install_maxwell_plugin: bool, default: False @@ -68,7 +68,7 @@ def __init__( well_name = self.stream_id # rec_name auto set by neo rec_name = self.neo_reader.rec_name - probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) + probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) self.set_property("electrode", self.get_property("contact_vector")["electrode"]) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 4b6af54bcd..24eea2d058 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -18,11 +18,11 @@ class MCSRawRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. - block_index: int, optional + block_index: int, default: None If there are several blocks, specify the block index you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 7dda9175f5..c0b820f65b 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -3,7 +3,7 @@ import numpy as np -import probeinterface as pi +import probeinterface from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor @@ -48,7 +48,7 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): self.extra_requirements.append("mearec") - probe = pi.read_mearec(file_path) + probe = probeinterface.read_mearec(file_path) probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"]) self.set_probe(probe, in_place=True) self.annotate(is_filtered=True) diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index 5c94d99b1e..78a52ae3e6 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -167,15 +167,15 @@ def __init__( Parameters ---------- - stream_id : Optional[str], default=None + stream_id : Optional[str], default: None The ID of the stream to extract from the data. - stream_name : Optional[str], default=None + stream_name : Optional[str], default: None The name of the stream to extract from the data. - block_index : Optional[int], default=None + block_index : Optional[int], default: None The index of the block to extract from the data. - all_annotations : bool, default=False + all_annotations : bool, default: False If True, include all annotations in the extracted data. - use_names_as_ids : Optional[bool], default=None + use_names_as_ids : Optional[bool], default: None If True, use channel names as IDs. Otherwise, use default IDs. neo_kwargs : Dict[str, Any] Additional keyword arguments to pass to the NeoBaseExtractor for initialization. @@ -402,10 +402,9 @@ def _infer_sampling_frequency_from_analog_signal(self, stream_id: Optional[str] Parameters ---------- - stream_id : str, optional + stream_id : str, default: None The ID of the stream from which to infer the sampling frequency. If not provided, - the function will look for a common sampling frequency across all streams. - (default is None) + the function will look for a common sampling frequency across all streams Returns ------- @@ -491,7 +490,7 @@ def _infer_t_start_from_signal_stream(self, segment_index: int, stream_id: Optio ---------- segment_index : int The index of the segment in which to look for the stream. - stream_id : str, optional + stream_id : str, default: None The ID of the stream from which to infer t_start. If not provided, the function will look for streams with a matching sampling frequency. diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 672602b66c..47452b8003 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -16,9 +16,9 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): ---------- folder_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. @@ -54,9 +54,9 @@ class NeuralynxSortingExtractor(NeoBaseSortingExtractor): sampling_frequency: float The sampling frequency for the spiking channels. When the signal data is available (.ncs) those files will be used to extract the frequency. Otherwise, the sampling frequency needs to be specified for this extractor. - stream_id: str, optional + stream_id: str, default: None Used to extract information about the sampling frequency and t_start from the analog signal if provided. - stream_name: str, optional + stream_name: str, default: None Used to extract information about the sampling frequency and t_start from the analog signal if provided. """ diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index 2c8603cb9c..94c6953a3d 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -36,10 +36,10 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. For this neo reader streams are defined by their sampling frequency. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index c652ce4fb9..2df95d4af5 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -25,11 +25,11 @@ class NeuroScopeRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to the binary container usually a .dat, .lfp, .eeg extension. - xml_file_path: str, optional + xml_file_path: str, default: None The path to the xml file. If None, the xml file is assumed to have the same name as the binary file. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. @@ -67,11 +67,11 @@ class NeuroScopeSortingExtractor(BaseSorting): """ Extracts spiking information from an arbitrary number of .res.%i and .clu.%i files in the general folder path. - The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer '%i') units. + The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer "%i") units. The .clu file is a file with one more row than the .res with the first row corresponding to the total number of unique ids in the file (and may exclude 0 & 1 from this count) with the rest of the rows indicating which unit id the corresponding entry in the .res file refers to. - The group id is loaded as unit property 'group'. + The group id is loaded as unit property "group". In the original Neuroscope format: Unit ID 0 is the cluster of unsorted spikes (noise). @@ -92,12 +92,12 @@ class NeuroScopeSortingExtractor(BaseSorting): clufile_path : PathType Optional. Path to a particular .clu text file. If given, only the single .clu file (and the respective .res file) are loaded - keep_mua_units : bool - Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True. + keep_mua_units : bool, default: True + Optional. Whether or not to return sorted spikes from multi-unit activity exclude_shanks : list Optional. List of indices to ignore. The set of all possible indices is chosen by default, extracted as the final integer of all the .res.%i and .clu.%i pairs. - xml_file_path : PathType, optional + xml_file_path : PathType, default: None Path to the .xml file referenced by this sorting. """ @@ -303,15 +303,16 @@ def read_neuroscope( file_path: str The xml file. stream_id: str or None - keep_mua_units: bool - Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True. + The stream id to load. If None, the first stream is loaded + keep_mua_units: bool, default: False + Optional. Whether or not to return sorted spikes from multi-unit activity exclude_shanks: list Optional. List of indices to ignore. The set of all possible indices is chosen by default, extracted as the final integer of all the .res. % i and .clu. % i pairs. - load_recording: bool - If True, the recording is loaded (default True) - load_sorting: bool - If True, the sorting is loaded (default False) + load_recording: bool, default: True + If True, the recording is loaded + load_sorting: bool, default: False + If True, the sorting is loaded """ outputs = () # TODO add checks for recording and sorting existence diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 2762e5645b..298b8c6019 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -15,11 +15,11 @@ class NixRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. - block_index: int, optional + block_index: int, default: None If there are several blocks, specify the block index you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index bb3ae3435a..6a37ab8d06 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -1,5 +1,4 @@ """ - There are two extractors for data saved by the Open Ephys GUI * OpenEphysLegacyRecordingExtractor: reads the original "Open Ephys" data format @@ -7,7 +6,6 @@ See https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/index.html for more info. - """ from pathlib import Path @@ -15,7 +13,7 @@ import numpy as np import warnings -import probeinterface as pi +import probeinterface from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor @@ -23,10 +21,10 @@ def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs): - # Temporary function until neo version 0.13.0 is released from packaging.version import Version from importlib.metadata import version as lib_version + # Temporary function until neo version 0.13.0 is released neo_version = lib_version("neo") # The possibility of ignoring timestamps errors is not present in neo <= 0.12.0 if Version(neo_version) <= Version("0.12.0"): @@ -49,17 +47,17 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): Parameters ---------- folder_path: str - The folder path to load the recordings from. - stream_id: str, optional - If there are several streams, specify the stream id you want to load. - stream_name: str, optional - If there are several streams, specify the stream name you want to load. - block_index: int, optional - If there are several blocks (experiments), specify the block index you want to load. - all_annotations: bool (default False) - Load exhaustively all annotation from neo. - ignore_timestamps_errors: bool (default False) - Ignore the discontinuous timestamps errors in neo. + The folder path to load the recordings from + stream_id: str, default: None + If there are several streams, specify the stream id you want to load + stream_name: str, default: None + If there are several streams, specify the stream name you want to load + block_index: int, default: None + If there are several blocks (experiments), specify the block index you want to load + all_annotations: bool, default: False + Load exhaustively all annotation from neo + ignore_timestamps_errors: bool, default: False + Ignore the discontinuous timestamps errors in neo """ mode = "folder" @@ -107,26 +105,26 @@ class OpenEphysBinaryRecordingExtractor(NeoBaseRecordingExtractor): Parameters ---------- folder_path: str - The folder path to the root folder (containing the record node folders). - load_sync_channel : bool - If False (default) and a SYNC channel is present (e.g. Neuropixels), this is not loaded. + The folder path to the root folder (containing the record node folders) + load_sync_channel : bool, default: False + If False (default) and a SYNC channel is present (e.g. Neuropixels), this is not loaded If True, the SYNC channel is loaded and can be accessed in the analog signals. - load_sync_timestamps : bool + load_sync_timestamps : bool, default: False If True, the synchronized_timestamps are loaded and set as times to the recording. If False (default), only the t_start and sampling rate are set, and timestamps are assumed - to be uniform and linearly increasing. - experiment_names: str, list, or None + to be uniform and linearly increasing + experiment_names: str, list, or None, default: None If multiple experiments are available, this argument allows users to select one or more experiments. If None, all experiements are loaded as blocks. - E.g. 'experiment_names="experiment2"', 'experiment_names=["experiment1", "experiment2"]' - stream_id: str, optional - If there are several streams, specify the stream id you want to load. - stream_name: str, optional - If there are several streams, specify the stream name you want to load. - block_index: int, optional - If there are several blocks (experiments), specify the block index you want to load. - all_annotations: bool (default False) - Load exhaustively all annotation from neo. + E.g. `experiment_names="experiment2"`, `experiment_names=["experiment1", "experiment2"]` + stream_id: str, default: None + If there are several streams, specify the stream id you want to load + stream_name: str, default: None + If there are several streams, specify the stream name you want to load + block_index: int, default: None + If there are several blocks (experiments), specify the block index you want to load + all_annotations: bool, default: False + Load exhaustively all annotation from neo """ @@ -178,7 +176,9 @@ def __init__( settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"] if Path(settings_file).is_file(): - probe = pi.read_openephys(settings_file=settings_file, stream_name=stream_name, raise_error=False) + probe = probeinterface.read_openephys( + settings_file=settings_file, stream_name=stream_name, raise_error=False + ) else: probe = None @@ -187,9 +187,16 @@ def __init__( self.set_probe(probe, in_place=True, group_mode="by_shank") else: self.set_probe(probe, in_place=True) - probe_name = probe.annotations["probe_name"] + + # this handles a breaking change in probeinterface after v0.2.18 + # in the new version, the Neuropixels model name is stored in the "model_name" annotation, + # rather than in the "probe_name" annotation + model_name = probe.annotations.get("model_name", None) + if model_name is None: + model_name = probe.annotations["probe_name"] + # load num_channels_per_adc depending on probe type - if "2.0" in probe_name: + if "2.0" in model_name: num_channels_per_adc = 16 num_cycles_in_adc = 16 total_channels = 384 @@ -203,7 +210,7 @@ def __init__( sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc) if self.get_num_channels() != total_channels: # need slice because not all channel are saved - chans = pi.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream) + chans = probeinterface.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream) # lets clip to 384 because this contains also the synchro channel chans = chans[chans < total_channels] sample_shifts = sample_shifts[chans] @@ -281,20 +288,38 @@ def map_to_neo_kwargs(cls, folder_path): def read_openephys(folder_path, **kwargs): """ - Read 'legacy' or 'binary' Open Ephys formats + Read "legacy" or "binary" Open Ephys formats Parameters ---------- folder_path: str or Path Path to openephys folder - stream_id: str, optional - If there are several streams, specify the stream id you want to load. - stream_name: str, optional - If there are several streams, specify the stream name you want to load. - block_index: int, optional - If there are several blocks (experiments), specify the block index you want to load. - all_annotations: bool (default False) - Load exhaustively all annotation from neo. + stream_id: str, default: None + If there are several streams, specify the stream id you want to load + stream_name: str, default: None + If there are several streams, specify the stream name you want to load + block_index: int, default: None + If there are several blocks (experiments), specify the block index you want to load + all_annotations: bool, default: False + Load exhaustively all annotation from neo + load_sync_channel : bool, default: False + If False (default) and a SYNC channel is present (e.g. Neuropixels), this is not loaded. + If True, the SYNC channel is loaded and can be accessed in the analog signals. + For open ephsy binary format only + load_sync_timestamps : bool, default: False + If True, the synchronized_timestamps are loaded and set as times to the recording. + If False (default), only the t_start and sampling rate are set, and timestamps are assumed + to be uniform and linearly increasing. + For open ephsy binary format only + experiment_names: str, list, or None, default: None + If multiple experiments are available, this argument allows users to select one + or more experiments. If None, all experiements are loaded as blocks. + E.g. `experiment_names="experiment2"`, `experiment_names=["experiment1", "experiment2"]` + For open ephsy binary format only + ignore_timestamps_errors: bool, default: False + Ignore the discontinuous timestamps errors in neo + For open ephsy legacy format only + Returns ------- @@ -313,13 +338,13 @@ def read_openephys(folder_path, **kwargs): def read_openephys_event(folder_path, block_index=None): """ - Read Open Ephys events from 'binary' format. + Read Open Ephys events from "binary" format. Parameters ---------- folder_path: str or Path Path to openephys folder - block_index: int, optional + block_index: int, default: None If there are several blocks (experiments), specify the block index you want to load. Returns diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index c3ff59fe82..b62bd473b6 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -15,9 +15,9 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 8dbfc67e90..d176e6546d 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -13,9 +13,9 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index af172855ed..a600c61c11 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -16,9 +16,9 @@ class Spike2RecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 49d55ca3eb..7c10270365 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -15,9 +15,9 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): ---------- file_path: str The file path to load the recordings from. - stream_id: str, optional + stream_id: str or None, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str or None, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 8c3b33505d..6a6901b62e 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -4,7 +4,7 @@ from pathlib import Path import neo -import probeinterface as pi +import probeinterface from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts @@ -22,21 +22,21 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): Based on :py:class:`neo.rawio.SpikeGLXRawIO` - Contrary to older verion this reader is folder based. - So if the folder contain several streams ('imec0.ap' 'nidq' 'imec0.lf') - then it has to be specified with 'stream_id'. + Contrary to older verions, this reader is folder-based. + If the folder contains several streams (e.g., "imec0.ap", "nidq" ,"imec0.lf"), + then the stream has to be specified with "stream_id" or "stream_name". Parameters ---------- folder_path: str The folder path to load the recordings from. - load_sync_channel: bool default False + load_sync_channel: bool default: False Whether or not to load the last channel in the stream, which is typically used for synchronization. If True, then the probe is not loaded. - stream_id: str, optional + stream_id: str or None, default: None If there are several streams, specify the stream id you want to load. - For example, 'imec0.ap' 'nidq' or 'imec0.lf'. - stream_name: str, optional + For example, "imec0.ap", "nidq", or "imec0.lf". + stream_name: str or None, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. @@ -60,7 +60,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ # Load probe geometry if available if "lf" in self.stream_id: meta_filename = meta_filename.replace(".lf", ".ap") - probe = pi.read_spikeglx(meta_filename) + probe = probeinterface.read_spikeglx(meta_filename) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") @@ -84,7 +84,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc) if self.get_num_channels() != total_channels: # need slice because not all channel are saved - chans = pi.get_saved_channel_indices_from_spikeglx_meta(meta_filename) + chans = probeinterface.get_saved_channel_indices_from_spikeglx_meta(meta_filename) # lets clip to 384 because this contains also the synchro channel chans = chans[chans < total_channels] sample_shifts = sample_shifts[chans] diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 60cd39c010..11c189bf34 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -15,9 +15,9 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor): ---------- folder_path: str The folder path to the tdt folder. - stream_id: str, optional + stream_id: str or None, default: None If there are several streams, specify the stream id you want to load. - stream_name: str, optional + stream_name: str or None, default: None If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. diff --git a/src/spikeinterface/extractors/neuropixels_utils.py b/src/spikeinterface/extractors/neuropixels_utils.py index 0db2394dc0..6bef869eb8 100644 --- a/src/spikeinterface/extractors/neuropixels_utils.py +++ b/src/spikeinterface/extractors/neuropixels_utils.py @@ -111,7 +111,7 @@ def synchronize_neuropixel_streams(recording_ref, recording_other): Method used : 1. detect pulse times on both streams. - 2. make a linear regression from 'other' to 'ref'. + 2. make a linear regression from "other" to "ref". The slope is nclose to 1 and corresponds to the sample rate correction The intercept is close to 0 and corresponds to the delta time start diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index bca4c75d99..f7b445cdb9 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -25,7 +25,7 @@ def retrieve_electrical_series(nwbfile: NWBFile, electrical_series_name: Optiona ---------- nwbfile : NWBFile The NWBFile object from which to extract the ElectricalSeries. - electrical_series_name : str, optional + electrical_series_name : str, default: None The name of the ElectricalSeries to extract. If not specified, it will return the first found ElectricalSeries if there's only one; otherwise, it raises an error. @@ -80,10 +80,10 @@ def read_nwbfile( ---------- file_path : Path, str The path to the NWB file. - stream_mode : "fsspec" or "ros3", optional - The streaming mode to use. Default assumes the file is on the local disk. - stream_cache_path : str, optional - The path to the cache storage. Default is None. + stream_mode : "fsspec" or "ros3" or None, default: None + The streaming mode to use. If None it assumes the file is on the local disk. + stream_cache_path : str or None, default: None + The path to the cache storage Returns ------- @@ -144,17 +144,17 @@ class NwbRecordingExtractor(BaseRecording): ---------- file_path: str or Path Path to NWB file or s3 url. - electrical_series_name: str, optional + electrical_series_name: str or None, default: None The name of the ElectricalSeries. Used if multiple ElectricalSeries are present. load_time_vector: bool, default: False If True, the time vector is loaded to the recording object. samples_for_rate_estimation: int, default: 100000 The number of timestamp samples to use to estimate the rate. - Used if 'rate' is not specified in the ElectricalSeries. - stream_mode: str, optional + Used if "rate" is not specified in the ElectricalSeries. + stream_mode: str or None, default: None Specify the stream mode: "fsspec" or "ros3". - stream_cache_path: str or Path, optional - Local path for caching. Default: cwd/cache. + stream_cache_path: str or Path or None, default: None + Local path for caching. If None it uses cwd Returns ------- @@ -424,17 +424,17 @@ class NwbSortingExtractor(BaseSorting): ---------- file_path: str or Path Path to NWB file. - electrical_series_name: str, optional + electrical_series_name: str or None, default: None The name of the ElectricalSeries (if multiple ElectricalSeries are present). - sampling_frequency: float, optional + sampling_frequency: float or None, default: None The sampling frequency in Hz (required if no ElectricalSeries is available). samples_for_rate_estimation: int, default: 100000 The number of timestamp samples to use to estimate the rate. - Used if 'rate' is not specified in the ElectricalSeries. - stream_mode: str, optional + Used if "rate" is not specified in the ElectricalSeries. + stream_mode: str or None, default: None Specify the stream mode: "fsspec" or "ros3". - stream_cache_path: str or Path, optional - Local path for caching. Default: cwd/cache. + stream_cache_path: str or Path or None, default: None + Local path for caching. If None it uses cwd Returns ------- @@ -590,14 +590,14 @@ def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_seri If True, the recording object is loaded. load_sorting : bool, default: False If True, the recording object is loaded. - electrical_series_name: str, optional + electrical_series_name: str or None, default: None The name of the ElectricalSeries (if multiple ElectricalSeries are present) Returns ------- extractors: extractor or tuple Single RecordingExtractor/SortingExtractor or tuple with both - (depending on 'load_recording'/'load_sorting') arguments. + (depending on "load_recording"/"load_sorting") arguments. """ outputs = () if load_recording: diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index 130c0ce47e..ccb97e31b3 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -2,7 +2,7 @@ import numpy as np -from probeinterface import read_prb, write_prb +import probeinterface from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import write_binary_recording, define_function_from_class @@ -69,7 +69,7 @@ def __init__(self, file_path): ) # load probe file - probegroup = read_prb(params["probe"]) + probegroup = probeinterface.read_prb(params["probe"]) self.set_probegroup(probegroup, in_place=True) self._kwargs = {"file_path": str(Path(file_path).absolute())} self.extra_requirements.extend(["hybridizer", "pyyaml"]) @@ -81,14 +81,14 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * Parameters ---------- recording: RecordingExtractor - The recording extractor to be converted and saved. + The recording extractor to be converted and saved save_path: str - Full path to desired target folder. + Full path to desired target folder initial_sorting_fn: str Full path to the initial sorting csv file (can also be generated - using write_sorting static method from the SHYBRIDSortingExtractor). - dtype: dtype - Type of the saved data. Default float32. + using write_sorting static method from the SHYBRIDSortingExtractor) + dtype: dtype, default: float32 + Type of the saved data **write_binary_kwargs: keyword arguments for write_to_binary_dat_format() function """ try: @@ -119,7 +119,7 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * # write probe file probe_fn = (save_path / probe_name).absolute() probegroup = recording.get_probegroup() - write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels()) + probeinterface.write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels()) # create parameters file parameters = dict( diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 257c1d566a..64c6499767 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,10 +1,10 @@ import unittest import platform import subprocess +import os from packaging import version import pytest -import numpy as np from spikeinterface.core.testing import check_recordings_equal from spikeinterface import get_global_dataset_folder @@ -16,6 +16,7 @@ EventCommonTestSuite, ) +ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) local_folder = get_global_dataset_folder() / "ephy_testing_data" @@ -277,6 +278,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index 2a97dfdb17..d281862789 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -41,25 +41,25 @@ def toy_example( Parameters ---------- - duration: float (or list if multi segment) - Duration in seconds (default 10). - num_channels: int - Number of channels (default 4). - num_units: int - Number of units (default 10). - sampling_frequency: float - Sampling frequency (default 30000). - num_segments: int - Number of segments (default 2). - spike_times: ndarray (or list of multi segment) - Spike time in the recording. - spike_labels: ndarray (or list of multi segment) + duration: float or list[float], default: 10 + Duration in seconds. If a list is provided, it will be the duration of each segment. + num_channels: int, default: 4 + Number of channels + num_units: int, default: 10 + Number of units + sampling_frequency: float, default: 30000 + Sampling frequency + num_segments: int, default: 2 + Number of segments. + spike_times: np.array or list[nparray] or None, default: None + Spike time in the recording + spike_labels: np.array or list[nparray] or None, default: None Cluster label for each spike time (needs to specified both together). # score_detection: int (between 0 and 1) - # Generate the sorting based on a subset of spikes compare with the trace generation. - firing_rate: float - The firing rate for the units (in Hz). - seed: int + # Generate the sorting based on a subset of spikes compare with the trace generation + firing_rate: float, default: 3.0 + The firing rate for the units (in Hz) + seed: int or None, default: None Seed for random initialization. Returns diff --git a/src/spikeinterface/extractors/tridesclousextractors.py b/src/spikeinterface/extractors/tridesclousextractors.py index 8b0ce37e7a..6bf248c62a 100644 --- a/src/spikeinterface/extractors/tridesclousextractors.py +++ b/src/spikeinterface/extractors/tridesclousextractors.py @@ -11,7 +11,7 @@ class TridesclousSortingExtractor(BaseSorting): ---------- folder_path : str or Path Path to the Tridesclous folder. - chan_grp : list, optional + chan_grp : list or None, default: None The channel group(s) to load. Returns diff --git a/src/spikeinterface/extractors/waveclussnippetstextractors.py b/src/spikeinterface/extractors/waveclussnippetstextractors.py index 2e4c28b12e..3bcda1ea70 100644 --- a/src/spikeinterface/extractors/waveclussnippetstextractors.py +++ b/src/spikeinterface/extractors/waveclussnippetstextractors.py @@ -97,12 +97,10 @@ def get_snippets( Parameters ---------- - indexes: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. - channel_indices: (Union[List, None], optional) - Indices of channels to return, or all channels if None. Defaults to None. + indices: list[int] + Indices of the snippets to return + channel_indices: Union[list, None], default: None + Indices of channels to return, or all channels if None Returns ------- @@ -122,10 +120,10 @@ def frames_to_indices(self, start_frame: Union[int, None] = None, end_frame: Uni Parameters ---------- - start_frame: (Union[int, None], optional) - start sample index, or zero if None. Defaults to None. - end_frame: (Union[int, None], optional) - end_sample, or number of samples if None. Defaults to None. + start_frame: Union[int, None], default: None + start sample index, or zero if Non + end_frame: Union[int, None], default: None + end_sample, or number of samples if None Returns ------- diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 33e0ff6c03..3aebd13797 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -1,12 +1,3 @@ -# This is kept in 0.97.0 and then will be removed -from .template_tools import ( - get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, - get_template_channel_sparsity, -) - from .template_metrics import ( TemplateMetricsCalculator, compute_template_metrics, diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 7e6c95a875..2aaf4d20b9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -153,12 +153,13 @@ def get_data(self, outputs="concatenated"): Get computed spike amplitudes. Parameters ---------- - outputs : str, optional - 'concatenated' or 'by_unit', by default 'concatenated' + outputs : "concatenated" | "by_unit", default: "concatenated" + The output format + Returns ------- spike_amplitudes : np.array or dict - The spike amplitudes as an array (outputs='concatenated') or + The spike amplitudes as an array (outputs="concatenated") or as a dict with units as key and spike amplitudes as values. """ we = self.waveform_extractor @@ -206,16 +207,16 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity, default: None + sparsity: ChannelSparsity or None, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, default: None + ms_before : float or None, default: None The cut out to apply before the spike peak to extract local waveforms. If None, the WaveformExtractor ms_before is used. - ms_after : float, default: None + ms_after : float or None, default: None The cut out to apply after the spike peak to extract local waveforms. If None, the WaveformExtractor ms_after is used. handle_collisions: bool, default: True @@ -226,18 +227,16 @@ def compute_amplitude_scalings( The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. - outputs: str, default: 'concatenated' - How the output should be returned: - - 'concatenated' - - 'by_unit' + outputs: "concatenated" | "by_unit", default: "concatenated" + How the output should be returned {} Returns ------- amplitude_scalings: np.array or list of dict The amplitude scalings. - - If 'concatenated' all amplitudes for all spikes and all units are concatenated - - If 'by_unit', amplitudes are returned as a list (for segments) of dictionaries (for units) + - If "concatenated" all amplitudes for all spikes and all units are concatenated + - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name): sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) @@ -591,9 +590,9 @@ def fit_collision( # ---------- # we : WaveformExtractor # The WaveformExtractor object. -# sparsity : ChannelSparsity, default=None +# sparsity : ChannelSparsity, default: None # The ChannelSparsity. If None, only main channels are plotted. -# num_collisions : int, default=None +# num_collisions : int, default: None # Number of collisions to plot. If None, all collisions are plotted. # """ # assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6e693635eb..369354fe04 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -146,15 +146,15 @@ def compute_correlograms( Parameters ---------- waveform_or_sorting_extractor : WaveformExtractor or BaseSorting - If WaveformExtractor, the correlograms are saved as WaveformExtensions. + If WaveformExtractor, the correlograms are saved as WaveformExtensions load_if_exists : bool, default: False - Whether to load precomputed crosscorrelograms, if they already exist. - window_ms : float, optional - The window in ms, by default 100.0. - bin_ms : float, optional - The bin size in ms, by default 5.0. - method : str, optional - "auto" | "numpy" | "numba". If _auto" and numba is installed, numba is used, by default "auto" + Whether to load precomputed crosscorrelograms, if they already exist + window_ms : float, default: 100.0 + The window in ms + bin_ms : float, default: 5 + The bin size in ms + method : "auto" | "numpy" | "numba", default: "auto" + If "auto" and numba is installed, numba is used, otherwise numpy is used Returns ------- diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index e98e64f753..1185e179b1 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -77,15 +77,15 @@ def compute_isi_histograms( Parameters ---------- waveform_or_sorting_extractor : WaveformExtractor or BaseSorting - If WaveformExtractor, the ISI histograms are saved as WaveformExtensions. + If WaveformExtractor, the ISI histograms are saved as WaveformExtensions load_if_exists : bool, default: False - Whether to load precomputed crosscorrelograms, if they already exist. - window_ms : float, optional - The window in ms, by default 50.0. - bin_ms : float, optional - The bin size in ms, by default 1.0. - method : str, optional - "auto" | "numpy" | "numba". If "auto" and numba is installed, numba is used, by default "auto" + Whether to load precomputed crosscorrelograms, if they already exist + window_ms : float, default: 50 + The window in ms + bin_ms : float, default: 1 + The bin size in ms + method : "auto" | "numpy" | "numba", default: "auto" + . If "auto" and numba is installed, numba is used, otherwise numpy is used Returns ------- diff --git a/src/spikeinterface/postprocessing/noise_level.py b/src/spikeinterface/postprocessing/noise_level.py index 8b5c04dab1..db93731977 100644 --- a/src/spikeinterface/postprocessing/noise_level.py +++ b/src/spikeinterface/postprocessing/noise_level.py @@ -56,13 +56,11 @@ def compute_noise_levels(waveform_extractor, load_if_exists=False, **params): Parameters ---------- waveform_extractor: WaveformExtractor - A waveform extractor object. - num_chunks_per_segment: int (deulf 20) - Number of chunks to estimate the noise - chunk_size: int (default 10000) - Size of chunks in sample - seed: int (default None) - Eventualy a seed for reproducibility. + A waveform extractor object + load_if_exists: bool, default: False + If True, the noise levels are loaded if they already exist + **params: dict with additional parameters + Returns ------- diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 8383dcbb43..cf32e79b25 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -80,7 +80,7 @@ def get_projections(self, unit_id, sparse=False): ---------- unit_id : int or str The unit id to return PCA projections for - sparse: bool, default False + sparse: bool, default: False If True, and sparsity is not None, only projections on sparse channels are returned. Returns @@ -122,18 +122,18 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): Parameters ---------- - channel_ids : list, optional + channel_ids : list, default: None List of channel ids on which projections are computed - unit_ids : list, optional + unit_ids : list, default: None List of unit ids to return projections for outputs: str - * 'id': 'all_labels' contain unit ids - * 'index': 'all_labels' contain unit indices + * "id": "all_labels" contain unit ids + * "index": "all_labels" contain unit indices Returns ------- all_labels: np.array - Array with labels (ids or indices based on 'outputs') of returned PCA projections + Array with labels (ids or indices based on "outputs") of returned PCA projections all_projections: np.array The PCA projections (num_all_waveforms, num_components, num_channels) """ @@ -169,7 +169,7 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): new_waveforms: np.array Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str - In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + In case PCA is sparse and mode is by_channel_local, the unit_id of "new_waveforms" sparse: bool, default: False If True, and sparsity is not None, only projections on sparse channels are returned. @@ -186,7 +186,7 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): wfs0 = self.waveform_extractor.get_waveforms(unit_id=self.waveform_extractor.sorting.unit_ids[0]) assert ( wfs0.shape[1] == new_waveforms.shape[1] - ), "Mismatch in number of samples between waveforms used to fit the pca model and 'new_waveforms" + ), "Mismatch in number of samples between waveforms used to fit the pca model and 'new_waveforms'" num_channels = len(self.waveform_extractor.channel_ids) # check waveform shapes @@ -200,7 +200,7 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): else: assert ( wfs0.shape[2] == new_waveforms.shape[2] - ), "Mismatch in number of channels between waveforms used to fit the pca model and 'new_waveforms" + ), "Mismatch in number of channels between waveforms used to fit the pca model and 'new_waveforms'" channel_inds = np.arange(num_channels, dtype=int) # get channel ids and pca models @@ -706,27 +706,28 @@ def compute_principal_components( The waveform extractor load_if_exists: bool If True and pc scores are already in the waveform extractor folders, pc scores are loaded and not recomputed. - n_components: int - Number of components fo PCA - default 5 - mode: str, default: 'by_channel_local' - - 'by_channel_local': a local PCA is fitted for each channel (projection by channel) - - 'by_channel_global': a global PCA is fitted for all channels (projection by channel) - - 'concatenated': channels are concatenated and a global PCA is fitted - sparsity: ChannelSparsity or None + n_components: int, default: 5 + Number of components fo PCA + mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" + The PCA mode: + - "by_channel_local": a local PCA is fitted for each channel (projection by channel) + - "by_channel_global": a global PCA is fitted for all channels (projection by channel) + - "concatenated": channels are concatenated and a global PCA is fitted + sparsity: ChannelSparsity or None, default: None The sparsity to apply to waveforms. - If waveform_extractor is already sparse, the default sparsity will be used - default None - whiten: bool - If True, waveforms are pre-whitened - default True - dtype: dtype - Dtype of the pc scores - default float32 - n_jobs: int - Number of jobs used to fit the PCA model (if mode is 'by_channel_local') - default 1 - progress_bar: bool - If True, a progress bar is shown - default False - tmp_folder: str + If waveform_extractor is already sparse, the default sparsity will be used + whiten: bool, default: True + If True, waveforms are pre-whitened + dtype: dtype, default: "float32" + Dtype of the pc scores + tmp_folder: str or Path or None, default: None The temporary folder to use for parallel computation. If you run several `compute_principal_components` - functions in parallel with mode 'by_channel_local', you need to specify a different `tmp_folder` for each call, - to avoid overwriting to the same folder - default None + functions in parallel with mode "by_channel_local", you need to specify a different `tmp_folder` for each call, + to avoid overwriting to the same folder + n_jobs: int, default: 1 + Number of jobs used to fit the PCA model (if mode is "by_channel_local") + progress_bar: bool, default: False + If True, a progress bar is shown Returns ------- diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index ccd2121174..50dac50ad3 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -93,13 +93,13 @@ def get_data(self, outputs="concatenated"): Parameters ---------- - outputs : str, optional - 'concatenated' or 'by_unit', by default 'concatenated' + outputs : "concatenated" | "by_unit", default: "concatenated" + The output format Returns ------- spike_amplitudes : np.array or dict - The spike amplitudes as an array (outputs='concatenated') or + The spike amplitudes as an array (outputs="concatenated") or as a dict with units as key and spike amplitudes as values. """ we = self.waveform_extractor @@ -148,25 +148,20 @@ def compute_spike_amplitudes( The waveform extractor object load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. - peak_sign: str - The sign to compute maximum channel: - - 'neg' - - 'pos' - - 'both' + peak_sign: "neg" | "pos" | "both", default: "neg + The sign to compute maximum channel return_scaled: bool If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV. - outputs: str - How the output should be returned: - - 'concatenated' - - 'by_unit' + outputs: "concatenated" | "by_unit", default: "concatenated" + How the output should be returned {} Returns ------- amplitudes: np.array or list of dict The spike amplitudes. - - If 'concatenated' all amplitudes for all spikes and all units are concatenated - - If 'by_unit', amplitudes are returned as a list (for segments) of dictionaries (for units) + - If "concatenated" all amplitudes for all spikes and all units are concatenated + - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ if load_if_exists and waveform_extractor.is_extension(SpikeAmplitudesCalculator.extension_name): sac = waveform_extractor.load_extension(SpikeAmplitudesCalculator.extension_name) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 28eed131cd..72d44bf348 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -82,13 +82,13 @@ def get_data(self, outputs="concatenated"): Parameters ---------- - outputs : str, optional - 'concatenated' or 'by_unit', by default 'concatenated' + outputs : "concatenated" | "by_unit", default: "concatenated" + The output format Returns ------- spike_locations : np.array or dict - The spike locations as a structured array (outputs='concatenated') or + The spike locations as a structured array (outputs="concatenated") or as a dict with units as key and spike locations as values. """ we = self.waveform_extractor @@ -140,38 +140,38 @@ def compute_spike_locations( Parameters ---------- waveform_extractor : WaveformExtractor - A waveform extractor object. + A waveform extractor object load_if_exists : bool, default: False - Whether to load precomputed spike locations, if they already exist. - ms_before : float - The left window, before a peak, in milliseconds. - ms_after : float - The right window, after a peak, in milliseconds. + Whether to load precomputed spike locations, if they already exist + ms_before : float, default: 0.5 + The left window, before a peak, in milliseconds + ms_after : float, default: 0.5 + The right window, after a peak, in milliseconds spike_retriver_kwargs: dict - A dictionary to control the behavior for getting the maximum channel for each spike. + A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: - * channel_from_template: bool, default True - For each spike is the maximum channel computed from template or re estimated at every spikes. + * channel_from_template: bool, default: True + For each spike is the maximum channel computed from template or re estimated at every spikes channel_from_template = True is old behavior but less acurate channel_from_template = False is slower but more accurate - * radius_um: float, default 50 - In case channel_from_template=False, this is the radius to get the true peak. - * peak_sign="neg" + * radius_um: float, default: 50 + In case channel_from_template=False, this is the radius to get the true peak + * peak_sign, default: "neg" In case channel_from_template=False, this is the peak sign. - method : str - 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' - method_kwargs : dict + method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + The localization method to use + method_kwargs : dict, default: dict() Other kwargs depending on the method. - outputs : str - 'concatenated' (default) / 'by_unit' + outputs : "concatenated" | "by_unit", default: "concatenated" + The output format {} Returns ------- spike_locations: np.array or list of dict The spike locations. - - If 'concatenated' all locations for all spikes and all units are concatenated - - If 'by_unit', locations are returned as a list (for segments) of dictionaries (for units) + - If "concatenated" all locations for all spikes and all units are concatenated + - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) """ if load_if_exists and waveform_extractor.is_extension(SpikeLocationsCalculator.extension_name): slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index a359e2a814..858af3ee08 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -241,12 +241,12 @@ def compute_template_metrics( Parameters ---------- - waveform_extractor : WaveformExtractor, optional + waveform_extractor : WaveformExtractor The waveform extractor used to compute template metrics load_if_exists : bool, default: False Whether to load precomputed template metrics, if they already exist. - metric_names : list, optional - List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None + metric_names : list or None, default: None + List of metrics to compute (see si.postprocessing.get_template_metric_names()) peak_sign : {"neg", "pos"}, default: "neg" Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 @@ -278,8 +278,8 @@ def compute_template_metrics( ------- template_metrics : pd.DataFrame Dataframe with the computed template metrics. - If 'sparsity' is None, the index is the unit_id. - If 'sparsity' is given, the index is a multi-index (unit_id, channel_id) + If "sparsity" is None, the index is the unit_id. + If "sparsity" is given, the index is a multi-index (unit_id, channel_id) Notes ----- diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index e026604b68..5febdf83f7 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -90,9 +90,9 @@ def compute_template_similarity( A waveform extractor object load_if_exists : bool, default: False Whether to load precomputed similarity, if is already exists. - method: str - Method name ('cosine_similarity') - waveform_extractor_other: WaveformExtractor, optional + method: str, default: "cosine_similarity" + The method to compute the similarity + waveform_extractor_other: WaveformExtractor, default: None A second waveform extractor object Returns @@ -143,7 +143,7 @@ def check_equal_template_with_distribution_overlap( template0 , template1=None or numpy array The average of each cluster. If None, then computed. - num_shift: int default 2 + num_shift: int default: 2 number of shift on each side to perform. quantile_limit: float in [0 1] The quantile overlap limit. diff --git a/src/spikeinterface/postprocessing/template_tools.py b/src/spikeinterface/postprocessing/template_tools.py deleted file mode 100644 index 0d992a0046..0000000000 --- a/src/spikeinterface/postprocessing/template_tools.py +++ /dev/null @@ -1,39 +0,0 @@ -# This is kept in 0.97.0 and then will be removed - -import warnings - -import spikeinterface.core.template_tools as tt - - -def _warn(): - warnings.warn( - "The spikeinterface.postprocessing.template_tools is submodule is deprecated." - "Use spikeinterface.core.template_tools instead", - DeprecationWarning, - stacklevel=2, - ) - - -def get_template_amplitudes(*args, **kwargs): - _warn() - return tt.get_template_amplitudes(*args, **kwargs) - - -def get_template_extremum_channel(*args, **kwargs): - _warn() - return tt.get_template_extremum_channel(*args, **kwargs) - - -def get_template_channel_sparsity(*args, **kwargs): - _warn() - return tt.get_template_channel_sparsity(*args, **kwargs) - - -def get_template_extremum_channel_peak_shift(*args, **kwargs): - _warn() - return tt.get_template_extremum_channel_peak_shift(*args, **kwargs) - - -def get_template_extremum_amplitude(*args, **kwargs): - _warn() - return tt.get_template_extremum_amplitude(*args, **kwargs) diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 48ceb34a4e..f665bac8d6 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -69,13 +69,13 @@ def get_data(self, outputs="numpy"): Parameters ---------- - outputs : str, optional - 'numpy' or 'by_unit', by default 'numpy' + outputs : "numpy" | "by_unit", default: "numpy" + The output format Returns ------- unit_locations : np.array or dict - The unit locations as a Nd array (outputs='numpy') or + The unit locations as a Nd array (outputs="numpy") or as a dict with units as key and locations as values. """ if outputs == "numpy": @@ -104,15 +104,15 @@ def compute_unit_locations( Parameters ---------- waveform_extractor: WaveformExtractor - A waveform extractor object. + A waveform extractor object load_if_exists : bool, default: False - Whether to load precomputed unit locations, if they already exist. - method: str - 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' - outputs: str - 'numpy' (default) / 'by_unit' + Whether to load precomputed unit locations, if they already exist + method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + The method to use for localization + outputs: "numpy" | "by_unit", default: "numpy" + The output format method_kwargs: - Other kwargs depending on the method. + Other kwargs depending on the method Returns ------- @@ -247,21 +247,21 @@ def compute_monopolar_triangulation( ---------- waveform_extractor:WaveformExtractor A waveform extractor object - method: str ('least_square', 'minimize_with_log_penality') - 2 variants of the method - radius_um: float + method: "least_square" | "minimize_with_log_penality", default: "least_square" + The optimizer to use + radius_um: float, default: 75 For channel sparsity - max_distance_um: float + max_distance_um: float, default: 1000 to make bounddary in x, y, z and also for alpha - return_alpha: bool default False + return_alpha: bool, default: False Return or not the alpha value - enforce_decrease : bool (default False) + enforce_decrease : bool, default: False Enforce spatial decreasingness for PTP vectors - feature: string in ['ptp', 'energy', 'peak_voltage'] + feature: "ptp" | "energy" | "peak_voltage", default: "ptp" The available features to consider for estimating the position via - monopolar triangulation are peak-to-peak amplitudes ('ptp', default), - energy ('energy', as L2 norm) or voltages at the center of the waveform - ('peak_voltage') + monopolar triangulation are peak-to-peak amplitudes ("ptp", default), + energy ("energy", as L2 norm) or voltages at the center of the waveform + ("peak_voltage") Returns ------- @@ -323,12 +323,12 @@ def compute_center_of_mass(waveform_extractor, peak_sign="neg", radius_um=75, fe ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels radius_um: float Radius to consider in order to estimate the COM - feature: str ['ptp', 'mean', 'energy', 'peak_voltage'] - Feature to consider for computation. Default is 'ptp' + feature: "ptp" | "mean" | "energy" | "peak_voltage", default: "ptp" + Feature to consider for computation Returns ------- @@ -387,24 +387,24 @@ def compute_grid_convolution( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - radius_um: float + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + radius_um: float, default: 40.0 Radius to consider for the fake templates - upsampling_um: float + upsampling_um: float, default: 5 Upsampling resolution for the grid of templates - sigma_um: np.array + sigma_um: np.array, default: np.linspace(5.0, 25.0, 5) Spatial decays of the fake templates - sigma_ms: float + sigma_ms: float, default: 0.25 The temporal decay of the fake templates - margin_um: float + margin_um: float, default: 50 The margin for the grid of fake templates - prototype: np.array + prototype: np.array or None, default: None Fake waveforms for the templates. If None, generated as Gaussian - percentile: float (default 10) + percentile: float, default: 10 The percentage in [0, 100] of the best scalar products kept to estimate the position - sparsity_threshold: float (default 0.01) + sparsity_threshold: float, default: 0.01 The sparsity threshold (in 0-1) below which weights should be considered as 0. Returns ------- diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 3f056dfada..b237b12a77 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -24,12 +24,12 @@ def __init__( ---------- parent_recording : BaseRecording recording to zero-pad - direction : str + direction : "x" | "y" | "z", default: "y" Channels living at unique positions along this direction will be averaged. - dtype : optional numpy dtype - If unset, parent dtype is preserved, but the average will - lose accuracy, so float32 by default. + dtype : numpy dtype or None, default: float32 + If None, parent dtype is preserved, but the average will + lose accuracy """ parent_channel_locations = parent_recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index cc18d51d2e..784f799279 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -15,10 +15,10 @@ class ClipRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be transformed - a_min: float or `None` (default `None`) + a_min: float or None, default: None Minimum value. If `None`, clipping is not performed on lower interval edge. - a_max: float or `None` (default `None`) + a_max: float or None, default: None Maximum value. If `None`, clipping is not performed on upper interval edge. @@ -59,22 +59,22 @@ class BlankSaturationRecording(BasePreprocessor): The recording extractor to be transformed Minimum value. If `None`, clipping is not performed on lower interval edge. - abs_threshold: float or None + abs_threshold: float or None, default: None The absolute value for considering that the signal is saturating - quantile_threshold: float or None + quantile_threshold: float or None, default: None Tha value in [0, 1] used if abs_threshold is None to automatically set the abs_threshold given the data. Must be provided if abs_threshold is None - direction: string in ['upper', 'lower', 'both'] - Only values higher than the detection threshold are set to fill_value ('higher'), - or only values lower than the detection threshold ('lower'), or both ('both') - fill_value: float or None + direction: "upper" | "lower" | "both", default: "upper" + Only values higher than the detection threshold are set to fill_value ("higher"), + or only values lower than the detection threshold ("lower"), or both ("both") + fill_value: float or None, default: None The value to write instead of the saturating signal. If None, then the value is automatically computed as the median signal value - num_chunks_per_segment: int (default 50) + num_chunks_per_segment: int, default: 50 The number of chunks per segments to consider to estimate the threshold/fill_values - chunk_size: int (default 500) + chunk_size: int, default: 500 The chunk size to estimate the threshold/fill_values - seed: int (default 0) + seed: int, default: 0 The seed to select the random chunks Returns diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 6d6ce256de..219854f340 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -16,22 +16,22 @@ class CommonReferenceRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be re-referenced - reference: str 'global', 'single' or 'local' - If 'global' then CMR/CAR is used either by groups or all channel way. - If 'single', the selected channel(s) is remove from all channels. operator is no used in that case. - If 'local', an average CMR/CAR is implemented with only k channels selected the nearest outside of a radius around each channel - operator: str 'median' or 'average' - If 'median', common median reference (CMR) is implemented (the median of + reference: "global" | "single" | "local", default: "global" + If "global" then CMR/CAR is used either by groups or all channel way. + If "single", the selected channel(s) is remove from all channels. operator is no used in that case. + If "local", an average CMR/CAR is implemented with only k channels selected the nearest outside of a radius around each channel + operator: "median" | "average", default: "median" + If "median", common median reference (CMR) is implemented (the median of the selected channels is removed for each timestamp). - If 'average', common average reference (CAR) is implemented (the mean of the selected channels is removed + If "average", common average reference (CAR) is implemented (the mean of the selected channels is removed for each timestamp). groups: list List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to single channels are applied group-wise. However, this is not applied for the local CAR. It is useful when dealing with different channel groups, e.g. multiple tetrodes. ref_channel_ids: list or int - If no 'groups' are specified, all channels are referenced to 'ref_channel_ids'. If 'groups' is provided, then a - list of channels to be applied to each group is expected. If 'single' reference, a list of one channel or an + If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a + list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an int is expected. local_radius: tuple(int, int) Use in the local CAR implementation as the selecting annulus (exclude radius, include radius) diff --git a/src/spikeinterface/preprocessing/correct_lsb.py b/src/spikeinterface/preprocessing/correct_lsb.py index fe2e5f00cb..bd1fc39230 100644 --- a/src/spikeinterface/preprocessing/correct_lsb.py +++ b/src/spikeinterface/preprocessing/correct_lsb.py @@ -14,14 +14,14 @@ def correct_lsb(recording, num_chunks_per_segment=20, chunk_size=10000, seed=Non ---------- recording : RecordingExtractor The recording extractor to be LSB-corrected. - num_chunks_per_segment: int - Number of chunks per segment for random chunk, by default 20 - chunk_size : int - Size of a chunk in number for random chunk, by default 10000 - seed : int - Random seed for random chunk, by default None - verbose : bool - If True, estimate LSB value is printed, by default False + num_chunks_per_segment: int, default: 20 + Number of chunks per segment for random chunk + chunk_size : int, default: 10000 + Size of a chunk in number for random chunk + seed : int or None, default: None + Random seed for random chunk + verbose : bool, default: False + If True, estimate LSB value is printed Returns ------- diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index 0525cdfc7a..43b35dfef9 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -2,7 +2,7 @@ import numpy as np from pathlib import Path -import probeinterface as pi +import probeinterface from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys from spikeinterface.preprocessing import depth_order, zscore @@ -29,7 +29,7 @@ def recording_and_shape(): num_cols = 2 num_rows = 64 - probe = pi.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows) + probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows) probe.set_device_channel_indices(np.arange(num_cols * num_rows)) recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000) recording.set_probe(probe, in_place=True) diff --git a/src/spikeinterface/preprocessing/deepinterpolation/train.py b/src/spikeinterface/preprocessing/deepinterpolation/train.py index c3b63dfc7e..9146d4099f 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/train.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/train.py @@ -84,36 +84,36 @@ def train_deepinterpolation( Number of frames after the frame to be predicted pre_post_omission : int Number of frames to be omitted before and after the frame to be predicted - existing_model_path : str | Path - Path to an existing model to be used for transfer learning, default is None - verbose : bool - Whether to print the progress of the training, default is True - steps_per_epoch : int + existing_model_path : str | Path | None, default: None + Path to an existing model to be used for transfer learning + verbose : bool, default: True + Whether to print the progress of the training + steps_per_epoch : int, default: 10 Number of steps per epoch - period_save : int + period_save : int, default: 100 Period of saving the model - apply_learning_decay : int + apply_learning_decay : int, default: 0 Whether to use a learning scheduler during training - nb_times_through_data : int + nb_times_through_data : int, default: 1 Number of times the data is repeated during training - learning_rate : float + learning_rate : float, default: 0.0001 Learning rate - loss : str + loss : str, default: "mean_squared_error" Loss function to be used - nb_workers : int + nb_workers : int, default: -1 Number of workers to be used for the training - caching_validation : bool - Whether to cache the validation data, default is False - run_uid : str + caching_validation : bool, default: False + Whether to cache the validation data + run_uid : str, default: "si" Unique identifier for the training - network_name : str - Name of the network to be used, default is None - use_gpu : bool - Whether to use GPU, default is True - disable_tf_logger : bool - Whether to disable the tensorflow logger, default is True - memory_gpu : int - Amount of memory to be used by the GPU, default is None + network_name : str, default: "unet_single_ephys_1024" + Name of the network to be used + use_gpu : bool, default: True + Whether to use GPU + disable_tf_logger : bool, default: True + Whether to disable the tensorflow logger + memory_gpu : int, default: None + Amount of memory to be used by the GPU Returns ------- diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 55e34ba5dd..7c4259d9af 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -14,10 +14,10 @@ class DepthOrderRecording(ChannelSliceRecording): The recording to re-order. channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str, tuple, list - If str, it needs to be 'x', 'y', 'z'. + dimensions : str or tuple, list, default: ("x", "y") + If str, it needs to be "x", "y", "z". If tuple or list, it sorts the locations in two dimensions using lexsort. - This approach is recommended since there is less ambiguity, by default ('x', 'y') + This approach is recommended since there is less ambiguity flip: bool, default: False If flip is False then the order is bottom first (starting from tip of the probe). If flip is True then the order is upper first. diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index e6e2836a35..a162cfe636 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -50,49 +50,43 @@ def detect_bad_channels( ---------- recording : BaseRecording The recording for which bad channels are detected - method : str - The method to be used: - - * coeherence+psd (default, developed by IBL) - * mad - * std - std_mad_threshold (mstd) : float - (method std, mad) + method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd" + The method to be used for bad channel detection + std_mad_threshold : float, default: 5 The standard deviation/mad multiplier threshold - psd_hf_threshold (coeherence+psd) : float + psd_hf_threshold (coeherence+psd) : float, default: 0.02 An absolute threshold (uV^2/Hz) used as a cutoff for noise channels. Channels with average power at >80% Nyquist larger than this threshold - will be labeled as noise, by default 0.02 - dead_channel_threshold (coeherence+psd) : float, optional - Threshold for channel coherence below which channels are labeled as dead, by default -0.5 - noisy_channel_threshold (coeherence+psd) : float - Threshold for channel coherence above which channels are labeled as noisy (together with psd condition), - by default 1 - outside_channel_threshold (coeherence+psd) : float + will be labeled as noise + dead_channel_threshold (coeherence+psd) : float, default: -0.5 + Threshold for channel coherence below which channels are labeled as dead + noisy_channel_threshold (coeherence+psd) : float, default: 1 + Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) + outside_channel_threshold (coeherence+psd) : float, default: -0.75 Threshold for channel coherence above which channels at the edge of the recording are marked as outside - of the brain, by default -0.75 - n_neighbors (coeherence+psd) : int - Number of channel neighbors to compute median filter (needs to be odd), by default 11 - nyquist_threshold (coeherence+psd) : float + of the brain + n_neighbors (coeherence+psd) : int, default: 11 + Number of channel neighbors to compute median filter (needs to be odd) + nyquist_threshold (coeherence+psd) : float, default: 0.8 Frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared - with psd_hf_threshold, by default 0.8 - direction (coeherence+psd): str - 'x', 'y', 'z', the depth dimension, by default 'y' - highpass_filter_cutoff : float - If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 - chunk_duration_s : float - Duration of each chunk, by default 0.5 - num_random_chunks : int - Number of random chunks, by default 100 + with psd_hf_threshold + direction (coeherence+psd): "x" | "y" | "z", default: "y" + The depth dimension + highpass_filter_cutoff : float, default: 300 + If the recording is not filtered, the cutoff frequency of the highpass filter + chunk_duration_s : float, default: 0.5 + Duration of each chunk + num_random_chunks : int, default: 100 + Number of random chunks Having many chunks is important for reproducibility. - welch_window_ms : float - Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms - neighborhood_r2_threshold : float, default 0.95 + welch_window_ms : float, default: 10 + Window size for the scipy.signal.welch that will be converted to nperseg + neighborhood_r2_threshold : float, default: 0.95 R^2 threshold for the neighborhood_r2 method. - neighborhood_r2_radius_um : float, default 30 + neighborhood_r2_radius_um : float, default: 30 Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. - seed : int or None - The random seed to extract chunks, by default None + seed : int or None, default: None + The random seed to extract chunks Returns ------- @@ -294,19 +288,19 @@ def detect_bad_channels_ibl( psd_hf_threshold : float Threshold for high frequency PSD. If mean PSD above `nyquist_threshold` * fn is greater than this value, channels are flagged as noisy (together with channel coherence condition). - dead_channel_thr : float, optional - Threshold for channel coherence below which channels are labeled as dead, by default -0.5 - noisy_channel_thr : float - Threshold for channel coherence above which channels are labeled as noisy (together with psd condition), - by default -0.5 - outside_channel_thr : float + dead_channel_thr : float, default: -0.5 + Threshold for channel coherence below which channels are labeled as dead + noisy_channel_thr : float, default: 1 + Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) + outside_channel_thr : float, default: -0.75 Threshold for channel coherence above which channels - n_neighbors : int, optional - Number of neighbors to compute median fitler, by default 11 - nyquist_threshold : float, optional - Threshold on Nyquist frequency to calculate HF noise band, by default 0.8 - welch_window_ms: float - Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms + n_neighbors : int, default: 11 + Number of neighbors to compute median fitler + nyquist_threshold : float, default: 0.8 + Threshold on Nyquist frequency to calculate HF noise band + welch_window_ms: float, default: 0.3 + Window size for the scipy.signal.welch that will be converted to nperseg + Returns ------- 1d array diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index d74b2a71ef..48bcf77d7f 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -18,8 +18,8 @@ def __init__( ): """Take derivative of any `order` along `direction` - np.gradient is applied independently along each colum (direction='y') - or row (direction='x'). Accounts for channel spacings and boundary + np.gradient is applied independently along each colum (direction="y") + or row (direction="x"). Accounts for channel spacings and boundary issues using np.gradient -- see that function's documentation for more information about `edge_order`. @@ -30,15 +30,15 @@ def __init__( ---------- recording : BaseRecording recording to zero-pad - direction : str - Gradients will be taken along this dimension. - order : int - np.gradient will be applied this many times. - edge_order : int + direction : "x" | "y" | "z", default: "y" + Gradients will be taken along this dimension + order : int, default: 1 + np.gradient will be applied this many times + edge_order : int, default: 1 Order of gradient accuracy at edges; see np.gradient for details. - dtype : optional numpy dtype - If unset, parent dtype is preserved, but the derivative can - overflow or lose accuracy, so "float32" by default. + dtype : numpy dtype or None, default: "float32" + If None, parent dtype is preserved, but the derivative can + overflow or lose accuracy """ parent_channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b31088edf7..1d6947be79 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -10,10 +10,10 @@ * filter_order: order The order of the filter - * filter_mode: 'sos or 'ba' - 'sos' is bi quadratic and more stable than ab so thery are prefered. + * filter_mode: "sos or "ba" + "sos" is bi quadratic and more stable than ab so thery are prefered. * ftype: str - Filter type for iirdesign ('butter' / 'cheby1' / ... all possible of scipy.signal.iirdesign) + Filter type for iirdesign ("butter" / "cheby1" / ... all possible of scipy.signal.iirdesign) """ @@ -30,20 +30,20 @@ class FilterRecording(BasePreprocessor): ---------- recording: Recording The recording extractor to be re-referenced - band: float or list - If float, cutoff frequency in Hz for 'highpass' filter type - If list. band (low, high) in Hz for 'bandpass' filter type - btype: str - Type of the filter ('bandpass', 'highpass') - margin_ms: float + band: float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype: "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms: float, default: 5.0 Margin in ms on border to avoid border effect - filter_mode: str 'sos' or 'ba' + filter_mode: "sos" | "ba", default: "sos" Filter form of the filter coefficients: - - second-order sections (default): 'sos' - - numerator/denominator: 'ba' - coef: ndarray or None + - second-order sections ("sos") + - numerator/denominator: ("ba") + coef: array or None, default: None Filter coefficients in the filter_mode form. - dtype: dtype or None + dtype: dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used {} diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index d3a08297c6..e16e2cfd08 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -18,7 +18,7 @@ class FilterOpenCLRecording(BasePreprocessor): """ Simple implementation of FilterRecording in OpenCL. - Only filter_mode='sos' is supported. + Only filter_mode="sos" is supported. Author : Samuel Garcia This kernel is ported from "tridesclous" @@ -29,9 +29,9 @@ class FilterOpenCLRecording(BasePreprocessor): The recording extractor to be re-referenced N: order - filter_mode: 'sos' only + filter_mode: "sos" only - ftypestr: 'butter' / 'cheby1' / ... all possible of scipy.signal.iirdesign + ftypestr: "butter" / "cheby1" / ... all possible of scipy.signal.iirdesign margin: margin in second on border to avoid border effect diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 4df4a409bc..37ebe6326a 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -26,25 +26,25 @@ class HighpassSpatialFilterRecording(BasePreprocessor): ---------- recording : BaseRecording The parent recording - n_channel_pad : int + n_channel_pad : int, default: 60 Number of channels to pad prior to filtering. Channels are padded with mirroring. - If None, no padding is applied, by default 60 - n_channel_taper : int + If None, no padding is applied + n_channel_taper : int, default: 0 Number of channels to perform cosine tapering on prior to filtering. If None and n_channel_pad is set, n_channel_taper will be set to the number of padded channels. - Otherwise, the passed value will be used, by default None - direction : str - The direction in which the spatial filter is applied, by default "y" - apply_agc : bool - It True, Automatic Gain Control is applied, by default True - agc_window_length_s : float - Window in seconds to compute Hanning window for AGC, by default 0.01 - highpass_butter_order : int - Order of spatial butterworth filter, by default 3 - highpass_butter_wn : float - Critical frequency (with respect to Nyquist) of spatial butterworth filter, by default 0.01 + Otherwise, the passed value will be used + direction : "x" | "y" | "z", default: "y" + The direction in which the spatial filter is applied + apply_agc : bool, default: True + It True, Automatic Gain Control is applied + agc_window_length_s : float, default: 0.1 + Window in seconds to compute Hanning window for AGC + highpass_butter_order : int, default: 3 + Order of spatial butterworth filter + highpass_butter_wn : float, default: 0.01 + Critical frequency (with respect to Nyquist) of spatial butterworth filter Returns ------- diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 95ecd0fe52..10c5a55265 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -23,15 +23,15 @@ class InterpolateBadChannelsRecording(BasePreprocessor): The parent recording bad_channel_ids : list or 1d np.array Channel ids of the bad channels to interpolate. - sigma_um : float + sigma_um : float or None, default: None Distance between sequential channels in um. If None, will use - the most common distance between y-axis channels, by default None - p : float + the most common distance between y-axis channels + p : float, default: 1.3 Exponent of the Gaussian kernel. Determines rate of decay - for distance weightings, by default 1.3 - weights : np.array + for distance weightings + weights : np.array or None, default: None The weights to give to bad_channel_ids at interpolation. - If None, weights are automatically computed, by default None + If None, weights are automatically computed Returns ------- diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 6ab1a9afce..8672b48340 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -200,14 +200,14 @@ def correct_motion( ---------- recording: RecordingExtractor The recording extractor to be transformed - preset: str - The preset name. Default "nonrigid_accurate". - folder: Path str or None - If not None then intermediate motion info are saved into a folder. Default None - output_motion_info: bool + preset: str, default: "nonrigid_accurate" + The preset name + folder: Path str or None, default: None + If not None then intermediate motion info are saved into a folder + output_motion_info: bool, default: False If True, then the function returns a `motion_info` dictionary that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - This dictionary is the same when reloaded from the folder. Default False + This dictionary is the same when reloaded from the folder detect_kwargs: dict Optional parameters to overwrite the ones in the preset for "detect" step. select_kwargs: dict diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index bd53866b6a..03afada380 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -35,18 +35,18 @@ class NormalizeByQuantileRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be transformed - scalar: float + scale: float, default: 1.0 Scale for the output distribution - median: float + median: float, default: 0.0 Median for the output distribution - q1: float (default 0.01) + q1: float, default: 0.01 Lower quantile used for measuring the scale - q1: float (default 0.99) + q1: float, default: 0.99 Upper quantile used for measuring the - seed: int - Random seed for reproducibility - dtype: str or np.dtype - The dtype of the output traces. Default "float32" + mode: "by_channel" | "pool_channel", default: "by_channel" + If "by_channel" each channel is rescaled independently. + dtype: str or np.dtype, default: "float32" + The dtype of the output traces **random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -123,8 +123,8 @@ class ScaleRecording(BasePreprocessor): Scalar for the traces of the recording extractor or array with scalars for each channel offset: float or array Offset for the traces of the recording extractor or array with offsets for each channel - dtype: str or np.dtype - The dtype of the output traces. Default "float32" + dtype: str or np.dtype, default: "float32" + The dtype of the output traces Returns ------- @@ -179,10 +179,10 @@ class CenterRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be centered - mode: str - 'median' (default) | 'mean' - dtype: str or np.dtype - The dtype of the output traces. Default "float32" + mode: "median" | "mean", default: "median" + The method used to center the traces + dtype: str or np.dtype, default: "float32" + The dtype of the output traces **random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -227,8 +227,8 @@ class ZScoreRecording(BasePreprocessor): ---------- recording: RecordingExtractor The recording extractor to be centered - mode: str - "median+mad" (default) or "mean+std" + mode: "median+mad" | "mean+std", default: "median+mad" + The mode to compute the zscore dtype: None or dtype If None the the parent dtype is kept. For integer dtype a int_scale must be also given. diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index bdba55038d..570ce48a5d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -23,12 +23,13 @@ class PhaseShiftRecording(BasePreprocessor): ---------- recording: Recording The recording. It need to have "inter_sample_shift" in properties. - margin_ms: float (default 40) - margin in ms for computation + margin_ms: float, default: 40.0 + Margin in ms for computation. 40ms ensure a very small error when doing chunk processing - inter_sample_shift: None or numpy array - If "inter_sample_shift" is not in recording.properties - we can externaly provide one. + inter_sample_shift: None or numpy array, default: None + If "inter_sample_shift" is not in recording properties, + we can externally provide one. + Returns ------- filter_recording: PhaseShiftRecording diff --git a/src/spikeinterface/preprocessing/preprocessing_tools.py b/src/spikeinterface/preprocessing/preprocessing_tools.py index 17b05df5ad..039d054b64 100644 --- a/src/spikeinterface/preprocessing/preprocessing_tools.py +++ b/src/spikeinterface/preprocessing/preprocessing_tools.py @@ -27,20 +27,20 @@ def get_spatial_interpolation_kernel( The recording extractor to be transformed target_location: array shape (n, 2) Scale for the output distribution - method: 'kriging' or 'idw' or 'nearest' + method: "kriging" | "idw" | "nearest", default: "kriging" Choice of the method - 'kriging' : the same one used in kilosort - 'idw' : inverse distance weithed - 'nearest' : use nereast channel - sigma_um : float or list (default 20.) - Used in the 'kriging' formula. When list, it needs to have 2 elements (for the x and y directions). - p: int (default 1) - Used in the 'kriging' formula - sparse_thresh: None or float (default None) - If not None for 'kriging' force small value to be zeros to get a sparse matrix. - num_closest: int (default 3) - Used for 'idw' - force_extrapolate: bool (false by default) + "kriging" : the same one used in kilosort + "idw" : inverse distance weithed + "nearest" : use neareast channel + sigma_um : float or list, default: 20.0 + Used in the "kriging" formula. When list, it needs to have 2 elements (for the x and y directions). + p: int, default: 1 + Used in the "kriging" formula + sparse_thresh: None or float, default: None + If not None for "kriging" force small value to be zeros to get a sparse matrix. + num_closest: int, default: 3 + Used for "idw" + force_extrapolate: bool, default: False How to handle when target location are outside source location. When False : no extrapolation all target location outside are set to zero. When True : extrapolation done with the formula of the method. @@ -123,10 +123,9 @@ def get_kriging_kernel_distance(locations_1, locations_2, sigma_um, p, distance_ Scale paremter on the Gaussian kernel, typically distance between contacts in micrometers. In case sigma_um is list then this mimics the Kilosort2.5 behavior, which uses two separate sigmas for each dimension. - In the later case the metric is always a 'cityblock' + In the later case the metric is always a "cityblock" p : float - Weight parameter on the exponential function. Default - in IBL kriging interpolation is 1.3. + Weight parameter on the exponential function Results ---------- diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 1eafa48a0b..534a3fb5a4 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -9,10 +9,10 @@ class RemoveArtifactsRecording(BasePreprocessor): """ Removes stimulation artifacts from recording extractor traces. By default, - artifact periods are zeroed-out (mode = 'zeros'). This is only recommended + artifact periods are zeroed-out (mode = "zeros"). This is only recommended for traces that are centered around zero (e.g. through a prior highpass filter); if this is not the case, linear and cubic interpolation modes are - also available, controlled by the 'mode' input argument. + also available, controlled by the "mode" input argument. Note that several artifacts can be removed at once (potentially with distinct duration each), if labels are specified @@ -22,66 +22,66 @@ class RemoveArtifactsRecording(BasePreprocessor): The recording extractor to remove artifacts from list_triggers: list of lists/arrays One list per segment of int with the stimulation trigger frames - ms_before: float or None + ms_before: float or None, default: 0.5 Time interval in ms to remove before the trigger events. If None, then also ms_after must be None and a single sample is removed - ms_after: float or None + ms_after: float or None, default: 3.0 Time interval in ms to remove after the trigger events. If None, then also ms_before must be None and a single sample is removed list_labels: list of lists/arrays or None One list per segment of labels with the stimulation labels for the given artefacs. labels should be strings, for JSON serialization. - Required for 'median' and 'average' modes. - mode: str + Required for "median" and "average" modes. + mode: "zeros", "linear", "cubic", "average", "median", default: "zeros" Determines what artifacts are replaced by. Can be one of the following: - - 'zeros' (default): Artifacts are replaced by zeros. + - "zeros": Artifacts are replaced by zeros. - - 'median': The median over all artifacts is computed and subtracted for + - "median": The median over all artifacts is computed and subtracted for each occurence of an artifact - - 'average': The mean over all artifacts is computed and subtracted for each + - "average": The mean over all artifacts is computed and subtracted for each occurence of an artifact - - 'linear': Replacement are obtained through Linear interpolation between + - "linear": Replacement are obtained through Linear interpolation between the trace before and after the artifact. If the trace starts or ends with an artifact period, the gap is filled with the closest available value before or after the artifact. - - 'cubic': Cubic spline interpolation between the trace before and after + - "cubic": Cubic spline interpolation between the trace before and after the artifact, referenced to evenly spaced fit points before and after the artifact. This is an option thatcan be helpful if there are significant LFP effects around the time of the artifact, but visual inspection of fit behaviour with your chosen settings is recommended. - The spacing of fit points is controlled by 'fit_sample_spacing', with + The spacing of fit points is controlled by "fit_sample_spacing", with greater spacing between points leading to a fit that is less sensitive to high frequency fluctuations but at the cost of a less smooth continuation of the trace. If the trace starts or ends with an artifact, the gap is filled with the closest available value before or after the artifact. - fit_sample_spacing: float + fit_sample_spacing: float, default: 1.0 Determines the spacing (in ms) of reference points for the cubic spline - fit if mode = 'cubic'. Default = 1ms. Note: The actual fit samples are + fit if mode = "cubic". Note: The actual fit samples are the median of the 5 data points around the time of each sample point to avoid excessive influence from hyper-local fluctuations. - artifacts: dict - If provided (when mode is 'median' or 'average') then it must be a dict with + artifacts: dict or None, default: None + If provided (when mode is "median" or "average") then it must be a dict with keys that are the labels of the artifacts, and values the artifacts themselves, on all channels (and thus bypassing ms_before and ms_after) - sparsity: dict - If provided (when mode is 'median' or 'average') then it must be a dict with + sparsity: dict or None, default: None + If provided (when mode is "median" or "average") then it must be a dict with keys that are the labels of the artifacts, and values that are boolean mask of the channels where the artifacts should be considered (for subtraction/scaling) - scale_amplitude: False - If true, then for mode 'median' or 'average' the amplitude of the template + scale_amplitude: False, default: False + If true, then for mode "median" or "average" the amplitude of the template will be scaled in amplitude at each time occurence to minimize residuals - time_jitter: float (default 0) - If non 0, then for mode 'median' or 'average', a time jitter in ms + time_jitter: float, default: 0 + If non 0, then for mode "median" or "average", a time jitter in ms can be allowed to minimize the residuals - waveforms_kwargs: dict or None + waveforms_kwargs: dict or None, default: None The arguments passed to the WaveformExtractor object when extracting the - artifacts, for mode 'median' or 'average'. - By default, the global job kwargs are used, in addition to {'allow_unfiltered' : True, 'mode':'memory'}. + artifacts, for mode "median" or "average". + By default, the global job kwargs are used, in addition to {"allow_unfiltered" : True, "mode":"memory"}. To estimate sparse artifact Returns diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index a00893846a..cffbdb2419 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -26,12 +26,12 @@ class ResampleRecording(BasePreprocessor): The recording extractor to be re-referenced resample_rate : int The resampling frequency - margin : float (default 100) + margin : float, default: 100.0 Margin in ms for computations, will be used to decrease edge effects. - dtype : dtype or None + dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used. - skip_checks : bool - If True, checks on sampling frequencies and cutoff filter frequencies are skipped, by default False + skip_checks : bool, default: False + If True, checks on sampling frequencies and cutoff filter frequencies are skipped Returns ------- diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c2ffcc6843..4299d199ed 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -9,7 +9,7 @@ class SilencedPeriodsRecording(BasePreprocessor): """ Silence user-defined periods from recording extractor traces. By default, - periods are zeroed-out (mode = 'zeros'). You can also fill the periods with noise. + periods are zeroed-out (mode = "zeros"). You can also fill the periods with noise. Note that both methods assume that traces that are centered around zero. If this is not the case, make sure you apply a filter or center function prior to silencing periods. @@ -21,12 +21,12 @@ class SilencedPeriodsRecording(BasePreprocessor): list_periods: list of lists/arrays One list per segment of tuples (start_frame, end_frame) to silence - mode: str + mode: "zeros" | "noise, default: "zeros" Determines what periods are replaced by. Can be one of the following: - - 'zeros' (default): Artifacts are replaced by zeros. + - "zeros": Artifacts are replaced by zeros. - - 'noise': The periods are filled with a gaussion noise that has the + - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis **random_chunk_kwargs: Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py index 32c1b938bf..d17617487f 100644 --- a/src/spikeinterface/preprocessing/tests/test_resample.py +++ b/src/spikeinterface/preprocessing/tests/test_resample.py @@ -38,14 +38,14 @@ def create_sinusoidal_traces(sampling_frequency=3e4, duration=30, freqs_n=10, ma Parameters ---------- - sampling_frequency : float, optional - Sampling rate of the signal, by default 3e4 - duration : int, optional - Duration of the signal in seconds, by default 30 - freqs_n : int, optional - Total frequencies to span on the signal, by default 10 - max_freq : int, optional - Maximum frequency of sinusoids, by default 10000 + sampling_frequency : float, default: 30000 + Sampling rate of the signal + duration : int, default: 30 + Duration of the signal in seconds + freqs_n : int, default: 10 + Total frequencies to span on the signal + max_freq : int, default: 10000 + Maximum frequency of sinusoids Returns ------- diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py index 75d64b0088..954f5ed7e8 100644 --- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py +++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py @@ -6,7 +6,7 @@ from spikeinterface.core import generate_recording from spikeinterface.core.numpyextractors import NumpyRecording -from spikeinterface.preprocessing import zero_channel_pad +from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, phase_shift from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording if hasattr(pytest, "global_test_folder"): @@ -39,7 +39,7 @@ def test_zero_padding_channel(): @pytest.fixture def recording(): num_channels = 4 - num_samples = 10 + num_samples = 10000 rng = np.random.default_rng(seed=0) traces = rng.random(size=(num_samples, num_channels)) traces_list = [traces] @@ -258,5 +258,74 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording, assert np.allclose(padded_traces_end, expected_zeros) +@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)]) +def test_trace_padded_recording_retrieve_only_start_padding(recording, padding_start, padding_end): + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Retrieve the padding at the start and test it + padded_traces_start = padded_recording.get_traces(start_frame=0, end_frame=padding_start) + expected_traces = np.zeros((padding_start, num_channels)) + assert np.allclose(padded_traces_start, expected_traces) + + +@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)]) +def test_trace_padded_recording_retrieve_only_end_padding(recording, padding_start, padding_end): + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Retrieve the padding at the end and test it + start_frame = padding_start + num_samples + end_frame = padding_start + num_samples + padding_end + padded_traces_end = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = np.zeros((padding_end, num_channels)) + assert np.allclose(padded_traces_end, expected_traces) + + +@pytest.mark.parametrize("preprocessing", ["bandpass_filter", "phase_shift"]) +@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)]) +def test_trace_padded_recording_retrieve_only_end_padding_with_preprocessing( + recording, padding_start, padding_end, preprocessing +): + """This is a tmeporary test to check that this works when the recording is called out of bonds. It should be removed + when more general test are added in that direction""" + + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + if preprocessing == "bandpass_filter": + recording = bandpass_filter(recording, freq_min=300, freq_max=6000) + else: + sample_shift_size = 0.4 + inter_sample_shift = np.arange(recording.get_num_channels()) * sample_shift_size + recording.set_property("inter_sample_shift", inter_sample_shift) + recording = phase_shift(recording) + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Retrieve the padding at the end and test it + start_frame = padding_start + num_samples + end_frame = padding_start + num_samples + padding_end + padded_traces_end = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = np.zeros((padding_end, num_channels)) + assert np.allclose(padded_traces_end, expected_traces) + + if __name__ == "__main__": test_zero_padding_channel() diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index ac80f58182..3bea9b91bb 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -18,11 +18,11 @@ class WhitenRecording(BasePreprocessor): dtype: None or dtype, default: None If None the the parent dtype is kept. For integer dtype a int_scale must be also given. - mode: 'global' / 'local', default: 'global' - 'global' use the entire covariance matrix to compute the W matrix - 'local' use local covariance (by radius) to compute the W matrix + mode: "global" | "local", default: "global" + "global" use the entire covariance matrix to compute the W matrix + "local" use local covariance (by radius) to compute the W matrix radius_um: None or float, default: None - Used for mode = 'local' to get the neighborhood + Used for mode = "local" to get the neighborhood apply_mean: bool, default: False Substract or not the mean matrix M before the dot product with W. int_scale : None or float, default: None @@ -33,7 +33,7 @@ class WhitenRecording(BasePreprocessor): Small epsilon to regularize SVD. If None, eps is defaulted to 1e-8. If the data is float type and scaled down to very small values, then the eps is automatically set to a small fraction (1e-3) of the median of the squared data. - W : 2d np.array, default: None + W : 2d np.array or None, default: None Pre-computed whitening matrix M : 1d np.array or None, default: None Pre-computed means. @@ -138,16 +138,16 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r mode : str The mode to compute the whitening matrix. - * 'global': compute SVD using all channels - * 'local': compute SVD on local neighborhood (controlled by `radius_um`) + * "global": compute SVD using all channels + * "local": compute SVD on local neighborhood (controlled by `radius_um`) random_chunk_kwargs : dict Keyword arguments for get_random_data_chunks() apply_mean : bool If True, the mean is removed prior to computing the covariance - radius_um : float, default: None - Used for mode = 'local' to get the neighborhood - eps : float, default: None + radius_um : float or None, default: None + Used for mode = "local" to get the neighborhood + eps : float or None, default: None Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 124b2b080e..c1ed31f508 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -17,12 +17,14 @@ class TracePaddedRecording(BasePreprocessor): ---------- parent_recording_segment : BaseRecording The parent recording segment from which the traces are to be retrieved. - padding_start : int - The amount of padding to add to the left of the traces. Default is 0. It has to be non-negative - padding_end : int - The amount of padding to add to the right of the traces. Default is 0. It has to be non-negative - fill_value: float - The value to pad with. Default is 0. + padding_start : int, default: 0 + The amount of padding to add to the left of the traces. It has to be non-negative. + Note that this counts the number of samples, not the number of seconds. + padding_end : int, default: 0 + The amount of padding to add to the right of the traces. It has to be non-negative + Note that this counts the number of samples, not the number of seconds + fill_value: float, default: 0 + The value to pad with """ def __init__( @@ -88,13 +90,18 @@ def get_traces(self, start_frame, end_frame, channel_indices): raise ValueError(f"Unsupported channel_indices type: {type(channel_indices)} raise an issue on github ") # This avoids an extra memory allocation if we are within the confines of the old traces - if start_frame > self.padding_start and end_frame < self.num_samples_in_original_segment + self.padding_start: + end_of_original_traces = self.num_samples_in_original_segment + self.padding_start + if start_frame > self.padding_start and end_frame < end_of_original_traces: return self.get_original_traces_shifted(start_frame, end_frame, channel_indices) - # Else, we start with the full padded traces and allocate the original traces in the middle + # We start with the full padded traces and fill in the original traces if necessary output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype) - # After the padding, the original traces are placed in the middle until the end of the original traces + # If start frame is larger than the end of the original traces, we return the padded traces as they are + if start_frame >= end_of_original_traces and end_frame > end_of_original_traces: + return output_traces + + # We add the original traces if end_frame is larger than the start of the original traces if end_frame >= self.padding_start: original_traces = self.get_original_traces_shifted( start_frame=start_frame, @@ -119,6 +126,7 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices): """ original_start_frame = max(start_frame - self.padding_start, 0) original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment) + original_traces = self.parent_recording_segment.get_traces( start_frame=original_start_frame, end_frame=original_end_frame, @@ -145,9 +153,8 @@ def __init__(self, parent_recording: BaseRecording, num_channels: int, channel_m recording to zero-pad num_channels : int Total number of channels in the zero-channel-padded recording - channel_mapping : Union[list, None], optional - Mapping from the channel index in the original recording to the zero-channel-padded recording, - by default None. + channel_mapping : Union[list, None], default: None + Mapping from the channel index in the original recording to the zero-channel-padded recording. If None, sorts the channel indices in ascending y channel location and puts them at the beginning of the zero-channel-padded recording. """ diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index d3f875959e..5c734b9100 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -184,9 +184,9 @@ def compute_snrs( ---------- waveform_extractor : WaveformExtractor The waveform extractor object. - peak_sign : {'neg', 'pos', 'both'} + peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. - peak_mode: {'extremum', 'at_index'} + peak_mode: "extremum" | "at_index", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima At_index takes the value at t=waveform_extractor.nbefore @@ -251,7 +251,7 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms= The waveform extractor object isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. - This is the biophysical refractory period (default=1.5). + This is the biophysical refractory period min_isi_ms : float, default: 0 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced @@ -422,19 +422,19 @@ def compute_sliding_rp_violations( ---------- waveform_extractor : WaveformExtractor The waveform extractor object. - min_spikes : int, default 0 + min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. - bin_size_ms : float - The size of binning for the autocorrelogram in ms, by default 0.25 - window_size_s : float - Window in seconds to compute correlogram, by default 1 - exclude_ref_period_below_ms : float - Refractory periods below this value are excluded, by default 0.5 - max_ref_period_ms : float - Maximum refractory period to test in ms, by default 10 ms - contamination_values : 1d array or None - The contamination values to test, by default np.arange(0.5, 35, 0.5) % + bin_size_ms : float, default: 0.25 + The size of binning for the autocorrelogram in ms + window_size_s : float, default: 1 + Window in seconds to compute correlogram + exclude_ref_period_below_ms : float, default: 0.5 + Refractory periods below this value are excluded + max_ref_period_ms : float, default: 10 + Maximum refractory period to test in ms + contamination_values : 1d array or None, default: None + The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5) unit_ids : list or None List of unit ids to compute the sliding RP violations. If None, all units are used. @@ -659,8 +659,8 @@ def compute_amplitude_cv_metrics( min_num_bins : int, default: 10 The minimum number of bins to compute the median and range. If the number of bins is less than this then the median and range are set to NaN. - amplitude_extension : str, default: 'spike_amplitudes' - The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'. + amplitude_extension : str, default: "spike_amplitudes" + The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". unit_ids : list or None List of unit ids to compute the amplitude spread. If None, all units are used. @@ -760,7 +760,7 @@ def compute_amplitude_cutoffs( ---------- waveform_extractor : WaveformExtractor The waveform extractor object. - peak_sign : {'neg', 'pos', 'both'} + peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. num_histogram_bins : int, default: 100 The number of bins to use to compute the amplitude histogram. @@ -856,7 +856,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None ---------- waveform_extractor : WaveformExtractor The waveform extractor object. - peak_sign : {'neg', 'pos', 'both'} + peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. @@ -929,29 +929,29 @@ def compute_drift_metrics( * drift_std: standard deviation of the drift signal * drift_mad: median absolute deviation of the drift signal - Requires 'spike_locations' extension. If this is not present, metrics are set to NaN. + Requires "spike_locations" extension. If this is not present, metrics are set to NaN. Parameters ---------- waveform_extractor : WaveformExtractor The waveform extractor object. - interval_s : int, optional - Interval length is seconds for computing spike depth, by default 60 - min_spikes_per_interval : int, optional - Minimum number of spikes for computing depth in an interval, by default 100 - direction : str, optional - The direction along which drift metrics are estimated, by default 'y' - min_fraction_valid_intervals : float, optional + interval_s : int, default: 60 + Interval length is seconds for computing spike depth + min_spikes_per_interval : int, default: 100 + Minimum number of spikes for computing depth in an interval + direction : "x" | "y" | "z", default: "y" + The direction along which drift metrics are estimated + min_fraction_valid_intervals : float, default: 0.5 The fraction of valid (not NaN) position estimates to estimate drifts. E.g., if 0.5 at least 50% of estimated positions in the intervals need to be valid, - otherwise drift metrics are set to None, by default 0.5 - min_num_bins : int, optional + otherwise drift metrics are set to None + min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. - return_positions : bool, optional - If True, median positions are returned (for debugging), by default False - unit_ids : list or None - List of unit ids to compute the drift metrics. If None, all units are used. + return_positions : bool, default: False + If True, median positions are returned (for debugging) + unit_ids : list or None, default: None + List of unit ids to compute the drift metrics. If None, all units are used Returns ------- @@ -1094,7 +1094,7 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None num_bin_edges : int, default: 101 The number of bins edges to use to compute the presence ratio. (mutually exclusive with bin_edges). - bin_n_spikes_thres: int, default 0 + bin_n_spikes_thres: int, default: 0 Minimum number of spikes within a bin to consider the unit active Returns @@ -1128,7 +1128,7 @@ def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_i The total duration of the recording (in seconds) isi_threshold_s : float, default: 0.0015 Threshold for classifying adjacent spikes as an ISI violation, in seconds. - This is the biophysical refractory period (default=1.5). + This is the biophysical refractory period min_isi_s : float, default: 0 Minimum possible inter-spike interval, in seconds. This is the artificial refractory period enforced @@ -1179,7 +1179,7 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val ---------- amplitudes : ndarray_like The amplitudes (in uV) of the spikes for one unit. - peak_sign : {'neg', 'pos', 'both'} + peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. num_histogram_bins : int, default: 500 The number of bins to use to compute the amplitude histogram. @@ -1249,16 +1249,16 @@ def slidingRP_violations( The acquisition sampling rate bin_size_ms : float The size (in ms) of binning for the autocorrelogram. - window_size_s : float - Window in seconds to compute correlogram, by default 2 - exclude_ref_period_below_ms : float - Refractory periods below this value are excluded, by default 0.5 - max_ref_period_ms : float - Maximum refractory period to test in ms, by default 10 ms - contamination_values : 1d array or None - The contamination values to test, by default np.arange(0.5, 35, 0.5) / 100 - return_conf_matrix : bool - If True, the confidence matrix (n_contaminations, n_ref_periods) is returned, by default False + window_size_s : float, default: 1 + Window in seconds to compute correlogram + exclude_ref_period_below_ms : float, default: 0.5 + Refractory periods below this value are excluded + max_ref_period_ms : float, default: 10 + Maximum refractory period to test in ms + contamination_values : 1d array or None, default: None + The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100 + return_conf_matrix : bool, default: False + If True, the confidence matrix (n_contaminations, n_ref_periods) is returned Code adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/master/python/slidingRP/metrics.py#L166 diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index ed06f7d738..6b13b1acbf 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -1,6 +1,5 @@ """Cluster quality metrics computed from principal components.""" -from cmath import nan from copy import deepcopy import numpy as np @@ -16,7 +15,6 @@ except: pass -import spikeinterface as si from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor from ..core.job_tools import tqdm_joblib from ..core.template_tools import get_template_extremum_channel @@ -72,12 +70,12 @@ def calculate_pc_metrics( ---------- pca : WaveformPrincipalComponent Waveform object with principal components computed. - metric_names : list of str, optional + metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - sparsity: ChannelSparsity or None + sparsity: ChannelSparsity or None, default: None The sparsity object. This is used also to identify neighbor - units and speed up computations. If None (default) all channels and all units are used + units and speed up computations. If None all channels and all units are used for each unit. qm_params : dict or None Dictionary with parameters for each PC metric function. @@ -393,11 +391,11 @@ def nearest_neighbors_isolation( Recomputed if None. max_spikes : int, default: 1000 Max number of spikes to use per unit. - min_spikes : int, optional, default: 10 + min_spikes : int, default: 10 Min number of spikes a unit must have to go through with metric computation. Units with spikes < min_spikes gets numpy.NaN as the quality metric, and are ignored when selecting other units' neighbors. - min_fr : float, optional, default: 0.0 + min_fr : float, default: 0.0 Min firing rate a unit must have to go through with metric computation. Units with firing rate < min_fr gets numpy.NaN as the quality metric, and are ignored when selecting other units' neighbors. @@ -407,7 +405,7 @@ def nearest_neighbors_isolation( The number of PC components to use to project the snippets to. radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. - peak_sign: str, default: 'neg' + peak_sign: "neg" | "pos" | "both", default: "neg" The peak_sign used to compute sparsity and neighbor units. Used if waveform_extractor is not sparse already. min_spatial_overlap : float, default: 100 @@ -599,10 +597,10 @@ def nearest_neighbors_noise_overlap( Recomputed if None. max_spikes : int, default: 1000 The max number of spikes to use per cluster. - min_spikes : int, optional, default: 10 + min_spikes : int, default: 10 Min number of spikes a unit must have to go through with metric computation. Units with spikes < min_spikes gets numpy.NaN as the quality metric. - min_fr : float, optional, default: 0.0 + min_fr : float, default: 0.0 Min firing rate a unit must have to go through with metric computation. Units with firing rate < min_fr gets numpy.NaN as the quality metric. n_neighbors : int, default: 5 @@ -611,7 +609,7 @@ def nearest_neighbors_noise_overlap( The number of PC components to use to project the snippets to. radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. - peak_sign: str, default: 'neg' + peak_sign: "neg" | "pos" | "both", default: "neg" The peak_sign used to compute sparsity and neighbor units. Used if waveform_extractor is not sparse already. seed : int, default: 0 diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index a2d0cc41b0..53309db282 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -198,10 +198,10 @@ def compute_quality_metrics( qm_params : dict or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - sparsity : dict or None + sparsity : dict or None, default: None If given, the sparse channel_ids for each unit in PCA metrics computation. This is used also to identify neighbor units and speed up computations. - If None (default) all channels and all units are used for each unit. + If None all channels and all units are used for each unit. skip_pc_metrics : bool If True, PC metrics computation is skipped. n_jobs : int diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 73bbee611b..eb8317e4df 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -210,7 +210,7 @@ def test_peak_sign(self): # invert recording rec_inv = scale(rec, gain=-1.0) - we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv") + we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv", seed=0) # compute amplitudes _ = compute_spike_amplitudes(we, peak_sign="neg") diff --git a/src/spikeinterface/qualitymetrics/utils.py b/src/spikeinterface/qualitymetrics/utils.py index 741308270b..4f2195b1a9 100644 --- a/src/spikeinterface/qualitymetrics/utils.py +++ b/src/spikeinterface/qualitymetrics/utils.py @@ -1,5 +1,5 @@ import numpy as np -from scipy.stats import norm, multivariate_normal +from scipy.stats import multivariate_normal def create_ground_truth_pc_distributions(center_locations, total_points): diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 139f15bf12..894918cbc4 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -140,8 +140,9 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if recording.check_serializablility("json"): recording.dump(rec_file, relative_to=output_folder) elif recording.check_serializablility("pickle"): - recording.dump(output_folder / "spikeinterface_recording.pickle") + recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: + # TODO: deprecate and finally remove this after 0.100 d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -203,7 +204,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False): else: recording = load_extractor(json_file, base_folder=output_folder) elif pickle_file.exists(): - recording = load_extractor(pickle_file) + recording = load_extractor(pickle_file, base_folder=output_folder) return recording diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 00ab3fbde5..ee9314b488 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -70,7 +70,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", - "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index dd9130b9ae..bf70e7b41d 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -80,7 +80,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "keep_good_only": "If True only 'good' units are returned", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", - "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index 77267620fa..bce8f181b9 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -77,7 +77,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", - "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_rez_to_mat": "Save the full rez internal struc to mat file", "delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that " diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index e7fdedcfe7..15098c8430 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -212,13 +212,13 @@ def run_sorter_by_property( **sorter_params, ): """ - Generic function to run a sorter on a recording after splitting by a 'grouping_property' (e.g. 'group'). + Generic function to run a sorter on a recording after splitting by a "grouping_property" (e.g. "group"). Internally, the function works as follows: - * the recording is split based on the provided 'grouping_property' (using the 'split_by' function) - * the 'run_sorters' function is run on the split recordings - * sorting outputs are aggregated using the 'aggregate_units' function - * the 'grouping_property' is added as a property to the SortingExtractor + * the recording is split based on the provided "grouping_property" (using the "split_by" function) + * the "run_sorters" function is run on the split recordings + * sorting outputs are aggregated using the "aggregate_units" function + * the "grouping_property" is added as a property to the SortingExtractor Parameters ---------- @@ -230,23 +230,23 @@ def run_sorter_by_property( Property to split by before sorting working_folder: str The working directory. - mode_if_folder_exists: None + mode_if_folder_exists: bool or None, default: None Must be None. This is deprecated. If not None then a warning is raise. Will be removed in next release. - engine: {'loop', 'joblib', 'dask'} + engine: "loop" | "joblib" | "dask", default: "loop" Which engine to use to run sorter. engine_kwargs: dict This contains kwargs specific to the launcher engine: - * 'loop' : no kwargs - * 'joblib' : {'n_jobs' : } number of processes - * 'dask' : {'client':} the dask client for submitting task - verbose: bool - default True - docker_image: None or str - If str run the sorter inside a container (docker) using the docker package. + * "loop" : no kwargs + * "joblib" : {"n_jobs" : } number of processes + * "dask" : {"client":} the dask client for submitting task + verbose: bool, default: False + Controls sorter verboseness + docker_image: None or str, default: None + If str run the sorter inside a container (docker) using the docker package **sorter_params: keyword args - Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)' + Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`) Returns ------- @@ -255,7 +255,7 @@ def run_sorter_by_property( Examples -------- - This example shows how to run spike sorting split by group using the 'joblib' backend with 4 jobs for parallel + This example shows how to run spike sorting split by group using the "joblib" backend with 4 jobs for parallel processing. >>> sorting = si.run_sorter_by_property("tridesclous", recording, grouping_property="group", @@ -334,18 +334,18 @@ def run_sorters( The working directory. sorter_params: dict of dict with sorter_name as key This allow to overwrite default params for sorter. - mode_if_folder_exists: {'raise', 'overwrite', 'keep'} + mode_if_folder_exists: "raise" | "overwrite" | "keep", default: "raise" The mode when the subfolder of recording/sorter already exists. - * 'raise' : raise error if subfolder exists - * 'overwrite' : delete and force recompute - * 'keep' : do not compute again if f=subfolder exists and log is OK - engine: {'loop', 'joblib', 'dask'} + * "raise" : raise error if subfolder exists + * "overwrite" : delete and force recompute + * "keep" : do not compute again if f=subfolder exists and log is OK + engine: "loop" | "joblib" | "dask", default: "loop" Which engine to use to run sorter. engine_kwargs: dict This contains kwargs specific to the launcher engine: - * 'loop' : no kwargs - * 'joblib' : {'n_jobs' : } number of processes - * 'dask' : {'client':} the dask client for submitting task + * "loop" : no kwargs + * "joblib" : {"n_jobs" : } number of processes + * "dask" : {"client":} the dask client for submitting task verbose: bool Controls sorter verboseness. with_output: bool diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index a49a605a75..ee788b8611 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -57,28 +57,27 @@ Path to output folder remove_existing_folder: bool If True and output_folder exists yet then delete. - delete_output_folder: bool - If True, output folder is deleted (default False) - verbose: bool + delete_output_folder: bool, default: False + If True, output folder is deleted + verbose: bool, default: False If True, output is verbose - raise_error: bool - If True, an error is raised if spike sorting fails (default). + raise_error: bool, default: True + If True, an error is raised if spike sorting fails If False, the process continues and the error is logged in the log file. - docker_image: bool or str + docker_image: bool or str, default: False If True, pull the default docker container for the sorter and run the sorter in that container using docker. Use a str to specify a non-default container. If that container is not local it will be pulled from docker hub. - If False, the sorter is run locally. - singularity_image: bool or str + If False, the sorter is run locally + singularity_image: bool or str, default: False If True, pull the default docker container for the sorter and run the sorter in that container using singularity. Use a str to specify a non-default container. If that container is not local it will be pulled - from Docker Hub. - If False, the sorter is run locally. - delete_container_files: bool - If True, the container temporary files are deleted after the sorting is done (default False). - with_output: bool - If True, the output Sorting is returned as a Sorting (default True). + from Docker Hub. If False, the sorter is run locally + delete_container_files: bool, default: True + If True, the container temporary files are deleted after the sorting is done + with_output: bool, default: True + If True, the output Sorting is returned as a Sorting **sorter_params: keyword args - Spike sorter specific arguments (they can be retrieved with 'get_default_params(sorter_name_or_class)' + Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`) Returns ------- @@ -229,8 +228,8 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): """ Parameters ---------- - mode: str - "docker" or "singularity" strings + mode: "docker" | "singularity" + The container mode container_image: str container image name and tag volumes: dict @@ -352,18 +351,30 @@ def run_sorter_container( Parameters ---------- sorter_name: str + The sorter name recording: BaseRecording + The recording extractor to be spike sorted mode: str - container_image: str, optional - output_folder: str, optional - remove_existing_folder: bool, optional - delete_output_folder: bool, optional - verbose: bool, optional - raise_error: bool, optional - with_output: bool, optional - delete_container_files: bool, optional - extra_requirements: list, optional - sorter_params: + The container mode: "docker" or "singularity" + container_image: str, default: None + The container image name and tag. If None, the default container image is used + output_folder: str, default: None + Path to output folder + remove_existing_folder: bool, default: True + If True and output_folder exists yet then delete + delete_output_folder: bool, default: False + If True, output folder is deleted + verbose: bool, default: False + If True, output is verbose + raise_error: bool, default: True + If True, an error is raised if spike sorting fails + with_output: bool, default: True + If True, the output Sorting is returned as a Sorting + delete_container_files: bool, default: True + If True, the container temporary files are deleted after the sorting is done + extra_requirements: list, default: None + List of extra requirements to install in the container + **sorter_params: keyword args for the sorter """ diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index d961bdbc07..9c862f4278 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -124,13 +124,13 @@ def run_matching_num_spikes(self, spike_num, seed=0, we_kwargs=None, template_mo Parameters ---------- spike_num: int - The maximum number of spikes per unit. - seed: int - Random seed. (Default: 0) + The maximum number of spikes per unit + seed: int, default: 0 + Random seed we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor. - template_mode: {'mean' | 'median' | 'std'} - The mode to use to extract templates from the WaveformExtractor. (Default: 'median') + A dictionary of keyword arguments for the WaveformExtractor + template_mode: "mean" | "median" | "std", default: "median" + The mode to use to extract templates from the WaveformExtractor Returns ------- @@ -161,8 +161,8 @@ def update_methods_kwargs(self, we, template_mode="median"): ---------- we: WaveformExtractor The new WaveformExtractor. - template_mode: {'mean' | 'median' | 'std'} - The mode to use to extract templates from the WaveformExtractor. (Default: 'median') + template_mode: "mean" | "median" | "std", default: "median" + The mode to use to extract templates from the WaveformExtractor Returns ------- @@ -188,14 +188,14 @@ def run_matching_misclassed( ---------- fraction_misclassed: float The fraction of misclassified spikes. - min_similarity: float - The minimum cosine similarity between templates to be considered similar. (Default: -1) - seed: int - Random seed. (Default: 0) + min_similarity: float, default: -1 + The minimum cosine similarity between templates to be considered similar + seed: int, default: 0 + Random seed we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor. - template_mode: {'mean' | 'median' | 'std'} - The mode to use to extract templates from the WaveformExtractor. (Default: 'median') + A dictionary of keyword arguments for the WaveformExtractor + template_mode: "mean" | "median" | "std", default: "median" + The mode to use to extract templates from the WaveformExtractor Returns ------- @@ -261,14 +261,14 @@ def run_matching_missing_units( ---------- fraction_missing: float The fraction of missing units. - snr_threshold: float - The SNR threshold below which units are considered missing. (Default: 0) - seed: int - Random seed. (Default: 0) + snr_threshold: float, default: 0 + The SNR threshold below which units are considered missing + seed: int, default: 0 + Random seed we_kwargs: dict A dictionary of keyword arguments for the WaveformExtractor. - template_mode: {'mean' | 'median' | 'std'} - The mode to use to extract templates from the WaveformExtractor. (Default: 'median') + template_mode: "mean" | "median" | "std", default: "median" + The mode to use to extract templates from the WaveformExtractor Returns ------- @@ -335,16 +335,16 @@ def run_matching_vary_parameter( ---------- parameters: array-like The values of the parameter to vary. - parameter_name: {'num_spikes', 'fraction_misclassed', 'fraction_missing} + parameter_name: "num_spikes", "fraction_misclassed", "fraction_missing" The name of the parameter to vary. - num_replicates: int - The number of replicates to run for each parameter value. (Default: 1) + num_replicates: int, default: 1 + The number of replicates to run for each parameter value we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor. - template_mode: {'mean' | 'median' | 'std'} - The mode to use to extract templates from the WaveformExtractor. (Default: 'median') + A dictionary of keyword arguments for the WaveformExtractor + template_mode: "mean" | "median" | "std", default: "median" + The mode to use to extract templates from the WaveformExtractor **kwargs - Keyword arguments for the run_matching method. + Keyword arguments for the run_matching method Returns ------- @@ -438,15 +438,15 @@ def compare_all_sortings(self, matching_df, collision=False, ground_truth="from_ A dataframe of NumpySortings for each method/parameter_value/iteration combination. collision: bool If True, use the CollisionGTComparison class. If False, use the compare_sorter_to_ground_truth function. - ground_truth: {'from_self' | 'from_df'} - If 'from_self', use the ground-truth sorting stored in the BenchmarkMatching object. If 'from_df', use the + ground_truth: "from_self" | "from_df", default: "from_self" + If "from_self", use the ground-truth sorting stored in the BenchmarkMatching object. If "from_df", use the ground-truth sorting stored in the matching_df. **kwargs Keyword arguments for the comparison function. Notes ----- - This function adds a new column to the matching_df called 'comparison' that contains the GroundTruthComparison + This function adds a new column to the matching_df called "comparison" that contains the GroundTruthComparison object for each row. """ if ground_truth == "from_self": diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 8902c873c7..fa1c860814 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -1,6 +1,6 @@ from .method_list import * -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs, _shared_job_kwargs_doc def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, extra_outputs=False, **job_kwargs): @@ -15,11 +15,12 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, peaks: WaveformExtractor The waveform extractor method: str - Which method to use ('stupid' | 'XXXX') - method_kwargs: dict, optional + Which method to use ("stupid" | "XXXX") + method_kwargs: dict, default: dict() Keyword arguments for the chosen method - extra_outputs: bool + extra_outputs: bool, default: False If True then debug is also return + {} Returns ------- @@ -44,3 +45,6 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, raise NotImplementedError return labels, peak_labels + + +find_cluster_from_peaks.__doc__ = find_cluster_from_peaks.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 1ed51fb04f..24ec923f06 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -7,7 +7,6 @@ import scipy.spatial from sklearn.decomposition import PCA from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from hdbscan import HDBSCAN import numpy as np import networkx as nx diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 68b34a7041..3139f1cf85 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -47,30 +47,8 @@ class SlidingNNClustering: - - """_summary_ - - TODO: - - 2D and higher d - - Args: - recording (_type_): _description_ - peaks (_type_): _description_ - time_window_s (_type_, optional): window for sampling nearest neighbors. Defaults to 60*5. - margin_ms (int, optional): margin for chunking. Defaults to 100. - ms_before (int, optional): time prior to peak. Defaults to 1. - ms_after (int, optional): time after peak. Defaults to 1. - n_channel_neighbors (int, optional): number of neighbors per channel. Defaults to 8. - n_neighbors (int, optional): number of neighbors for graph construction. Defaults to 5. - embedding_dim (int, optional): Number of embedding dimensions. Defaults to number of channels in recording. - knn_verbose (bool, optional): whether to make knn computation verbose. Defaults to True. - low_memory (bool, optional): memory usage for nearest neighbor computation. Defaults to False. - n_jobs (int, optional): number of jobs to perform computations over. Defaults to -1. - suppress_tqdm (bool, optional): Whether to display tqdm progress bar. Defaults to False. - Returns: - nn_idx (array, # spikes x # 2*n_neighbors): Graph of nearest neighbor indices - nn_dist (array, # spikes x # 2*n_neighbors): Distances between nearest neighbor points - + """ + Sliding window nearest neighbor clustering. """ _default_params = { @@ -413,20 +391,7 @@ def sparse_euclidean(x, y, n_samples, n_dense): # HACK: this function only exists because I couldn't get the spikeinterface one to work... def retrieve_padded_trace(recording, start_frame, end_frame, margin_frames, channel_ids=None): - """Grabs a chunk of recording trace, with padding - NOTE: I tried using the built in spikeinterface function for this but - recieved an error. - - Args: - recording (_type_): _description_ - start_frame (_type_): _description_ - end_frame (_type_): _description_ - margin_frames (_type_): _description_ - channel_ids (_type_, optional): _description_. Defaults to None. - - Returns: - _type_: _description_ - """ + """Grabs a chunk of recording trace, with padding""" n_frames = recording.get_num_frames() # get the padding _pre = np.max([0, start_frame - margin_frames]) @@ -452,23 +417,7 @@ def get_chunk_spike_waveforms( n_channel_neighbors=5, margin_frames=3000, ): - """Grabs the spike waveforms for a chunk of a recording - Args: - recording (_type_): _description_ - start_frame (_type_): _description_ - end_frame (_type_): _description_ - peaks (_type_): _description_ - channel_neighbors (_type_): _description_ - spike_pre_frames (int, optional): _description_. Defaults to 30. - spike_post_frames (int, optional): _description_. Defaults to 30. - n_channel_neighbors (int, optional): _description_. Defaults to 5. - margin_frames (int, optional): _description_. Defaults to 3000. - - Returns: - all_spikes: spike waveforms - all_chan_idx: channel indices of nearest neighbors - peaks_in_chunk_idx: index of spikes in this chunk - """ + """Grabs the spike waveforms for a chunk of a recording.""" # grab the trace traces = retrieve_padded_trace(recording, start_frame, end_frame, margin_frames, channel_ids=None) @@ -577,22 +526,7 @@ def swap_elements(l, idx1, idx2): def merge_nn_dicts(peaks, n_neighbors, peaks_in_chunk_idx_list, knn_indices_list, knn_distances_list): - """merge together peaks_in_chunk_idx_list and knn_indices_list - to build final graph - - Args: - peaks (_type_): array of peaks - n_neighbors (_type_): number of neighbors - peaks_in_chunk_idx_list (_type_): list of spike index - knn_indices_list (_type_): indices of connections - knn_distances_list (_type_): distances - - Raises: - ValueError: _description_ - - Returns: - _type_: _description_ - """ + """Merge together peaks_in_chunk_idx_list and knn_indices_list to build final graph.""" nn_index_array = np.zeros((len(peaks), n_neighbors * 2), dtype=int) - 1 nn_distance_array = np.zeros((len(peaks), n_neighbors * 2), dtype=float) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 48ec26679e..55ef0ced40 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -3,7 +3,6 @@ from tqdm.auto import tqdm from sklearn.decomposition import TruncatedSVD -from hdbscan import HDBSCAN import numpy as np @@ -37,17 +36,17 @@ def split_clusters( recording: Recording Recording object features_dict_or_folder: dict or folder - A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. - method: str + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features + method: str, default: "hdbscan_on_local_pca" The method name - method_kwargs: dict + method_kwargs: dict, default: dict() The method option - recursive: bool Default False - Reccursive or not. - recursive_depth: None or int - If recursive=True, then this is the max split per spikes. - returns_split_count: bool - Optionally return the split count vector. Same size as labels. + recursive: bool, default: False + Recursive or not + recursive_depth: None or int, default: None + If recursive=True, then this is the max split per spikes + returns_split_count: bool, default: False + Optionally return the split count vector. Same size as labels Returns ------- @@ -218,6 +217,8 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": + from hdbscan import HDBSCAN + clust = HDBSCAN( min_cluster_size=min_cluster_size, min_samples=min_samples, diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 9ef036de35..f7f020d153 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -28,10 +28,11 @@ def compute_features_from_peaks( Parameters ---------- recording: RecordingExtractor - The recording extractor object. + The recording extractor object peaks: array - Peaks array, as returned by detect_peaks() in "compact_numpy" way. - feature_list: List of features to be computed. + Peaks array, as returned by detect_peaks() in "compact_numpy" way + feature_list: list, default: ["ptp"] + List of features to be computed. Possible features are: - amplitude - ptp - center_of_mass @@ -40,10 +41,10 @@ def compute_features_from_peaks( - ptp_lag - random_projections_ptp - random_projections_energy - ms_before: float - The duration in ms before the peak for extracting the features (default 1 ms) - ms_after: float - The duration in ms after the peakfor extracting the features (default 1 ms) + ms_before: float, default: 1.0 + The duration in ms before the peak for extracting the features + ms_after: float, default: 1.0 + The duration in ms after the peakfor extracting the features {} diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 260c6a89f3..eec9052e7c 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -12,8 +12,8 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr ---------- recording: RecordingExtractor The recording extractor object - method: str - Which method to use ('naive' | 'tridesclous' | 'circus' | 'circus-omp' | 'wobble') + method: "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble" + Which method to use for template matching method_kwargs: dict, optional Keyword arguments for the chosen method extra_outputs: bool @@ -30,8 +30,8 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr Notes ----- - For all methods except 'wobble', templates are represented as a WaveformExtractor in method_kwargs - so statistics can be extracted. For 'wobble' templates are represented as a numpy.ndarray. + For all methods except "wobble", templates are represented as a WaveformExtractor in method_kwargs + so statistics can be extracted. For "wobble" templates are represented as a numpy.ndarray. """ from .method_list import matching_methods diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 4b04462e45..9482d50f9a 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -1,10 +1,8 @@ """Sorting components: template matching.""" import numpy as np -from spikeinterface.core import WaveformExtractor +from spikeinterface.core import WaveformExtractor, get_template_channel_sparsity, get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks -from spikeinterface.postprocessing import get_template_channel_sparsity, get_template_extremum_channel - from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive spike_dtype = [ diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 5327e28916..07b9f8baa4 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -42,8 +42,8 @@ class WobbleParameters: Notes ----- - 'Peaks' refer to relative maxima in the convolution of the templates with the voltage trace - (or residual) and 'spikes' refer to putative extracellular action potentials (EAPs). Peaks are considered spikes + "Peaks" refer to relative maxima in the convolution of the templates with the voltage trace + (or residual) and "spikes" refer to putative extracellular action potentials (EAPs). Peaks are considered spikes if their amplitude clears the threshold parameter. """ @@ -107,8 +107,8 @@ class TemplateMetadata: Notes ----- - A 'unit' refers to a putative neuron which may have one or more 'templates' of its spike waveform. - Each 'template' may have many upsampled 'jittered_templates' depending on the 'jitter_factor'. + A "unit" refers to a putative neuron which may have one or more "templates" of its spike waveform. + Each "template" may have many upsampled "jittered_templates" depending on the "jitter_factor". """ num_samples: int @@ -275,21 +275,21 @@ def __post_init__(self): class WobbleMatch(BaseTemplateMatchingEngine): """Template matching method from the Paninski lab. - Templates are jittered or 'wobbled' in time and amplitude to capture variability in spike amplitude and + Templates are jittered or "wobbled" in time and amplitude to capture variability in spike amplitude and super-resolution jitter in spike timing. Algorithm --------- At initialization: - 1. Compute channel sparsity to determine which units are 'visible' to each other + 1. Compute channel sparsity to determine which units are "visible" to each other 2. Compress Templates using Singular Value Decomposition into rank approx_rank 3. Upsample the temporal component of compressed templates and re-index to obtain many super-resolution-jittered temporal components for each template 3. Convolve each pair of jittered compressed templates together (subject to channel sparsity) For each chunk of traces: - 1. Compute the 'objective function' to be minimized by convolving each true template with the traces + 1. Compute the "objective function" to be minimized by convolving each true template with the traces 2. Normalize the objective relative to the magnitude of each true template - 3. Detect spikes by indexing peaks in the objective corresponding to 'matches' between the spike and a template + 3. Detect spikes by indexing peaks in the objective corresponding to "matches" between the spike and a template 4. Determine which super-resolution-jittered template best matches each spike and scale the amplitude to match 5. Subtract scaled pairwise convolved jittered templates from the objective(s) to account for the effect of removing detected spikes from the traces @@ -299,11 +299,11 @@ class WobbleMatch(BaseTemplateMatchingEngine): Notes ----- For consistency, throughout this module - - a 'unit' refers to a putative neuron which may have one or more 'templates' of its spike waveform - - Each 'template' may have many upsampled 'jittered_templates' depending on the 'jitter_factor' - - 'peaks' refer to relative maxima in the convolution of the templates with the voltage trace - - 'spikes' refer to putative extracellular action potentials (EAPs) - - 'peaks' are considered spikes if their amplitude clears the threshold parameter + - a "unit" refers to a putative neuron which may have one or more "templates" of its spike waveform + - Each "template" may have many upsampled "jittered_templates" depending on the "jitter_factor" + - "peaks" refer to relative maxima in the convolution of the templates with the voltage trace + - "spikes" refer to putative extracellular action potentials (EAPs) + - "peaks" are considered spikes if their amplitude clears the threshold parameter """ default_params = { @@ -512,7 +512,7 @@ def find_peaks(cls, objective, objective_normalized, spike_trains, params, templ scalings : ndarray (num_spikes,) Amplitude scaling used for each spike. distance_metric : ndarray (num_spikes) - A metric that describes how good of a 'fit' each spike is to its corresponding template + A metric that describes how good of a "fit" each spike is to its corresponding template Notes ----- diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 1c7d2a53e7..df73575a01 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -47,57 +47,58 @@ def estimate_motion( peaks: numpy array Peak vector (complex dtype) peak_locations: numpy array - Complex dtype with 'x', 'y', 'z' fields + Complex dtype with "x", "y", "z" fields {method_doc} **histogram section** - direction: 'x', 'y', 'z' + direction: "x" | "y" | "z", default: "y" Dimension on which the motion is estimated - bin_duration_s: float + bin_duration_s: float, default: 10 Bin duration in second - bin_um: float (default 10.) - Spatial bin size in micro meter - margin_um: float (default 0.) + bin_um: float, default: 10 + Spatial bin size in micrometers + margin_um: float, default: 0 Margin in um to exclude from histogram estimation and non-rigid smoothing functions to avoid edge effects. Positive margin extrapolate out of the probe the motion. - Negative margin crop the motion on the border. + Negative margin crop the motion on the border **non-rigid section** - rigid : bool (default True) + rigid : bool, default: False Compute rigid (one motion for the entire probe) or non rigid motion Rigid computation is equivalent to non-rigid with only one window with rectangular shape. - win_shape: 'gaussian' or 'rect' or 'triangle' + win_shape: "gaussian" | "rect" | "triangle", default: "gaussian" The shape of the windows for non rigid. - When rigid this is force to 'rect' - win_step_um: float (default 50.) + When rigid this is force to "rect" + win_step_um: float, default: 50 Step deteween window - win_sigma_um: float (deafult 150.) + win_sigma_um: float, default: 150 + Sigma of the gaussian window **motion cleaning section** - post_clean: bool (default False) - Apply some post cleaning to motion matrix or not. - speed_threshold: float default 30. - Detect to fast motion bump and remove then with interpolation. + post_clean: bool, default: False + Apply some post cleaning to motion matrix or not + speed_threshold: float default: 30. + Detect to fast motion bump and remove then with interpolation sigma_smooth_s: None or float - Optional smooting gaussian kernel when not None. + Optional smooting gaussian kernel when not None - output_extra_check: bool + output_extra_check: bool, default: False If True then return an extra dict that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - upsample_to_histogram_bin: bool or None + upsample_to_histogram_bin: bool or None, default: False If True then upsample the returned motion array to the number of depth bins specified by bin_um. When None: * for non rigid case: then automatically True * for rigid (non_rigid_kwargs=None): automatically False - This feature is in fact a bad idea and the interpolation should be done outside using better methods. - progress_bar: bool - Display progress bar or not. - verbose: bool + This feature is in fact a bad idea and the interpolation should be done outside using better methods + progress_bar: bool, default: False + Display progress bar or not + verbose: bool, default: False If True, output is verbose @@ -216,19 +217,19 @@ class DecentralizedRegistration: histogram_time_smooth_s: None or float Optional gaussian smoother on histogram on time axis. This is given as the sigma of the gaussian in seconds. - pairwise_displacement_method: 'conv' or 'phase_cross_correlation' + pairwise_displacement_method: "conv" or "phase_cross_correlation" How to estimate the displacement in the pairwise matrix. max_displacement_um: float Maximum possible discplacement in micrometers. - weight_scale: 'linear' or 'exp' + weight_scale: "linear" or "exp" For parwaise displacement, how to to rescale the associated weight matrix. - error_sigma: float 0.2 - In case weight_scale='exp' this controls the sigma of the exponential. - conv_engine: 'numpy' or 'torch' - In case of pairwise_displacement_method='conv', what library to use to compute - the underlying correlation. + error_sigma: float, default: 0.2 + In case weight_scale="exp" this controls the sigma of the exponential. + conv_engine: "numpy" or "torch" or None, default: None + In case of pairwise_displacement_method="conv", what library to use to compute + the underlying correlation torch_device=None - In case of conv_engine='torch', you can control which device (cpu or gpu) + In case of conv_engine="torch", you can control which device (cpu or gpu) batch_size: int Size of batch for the convolution. Increasing this will speed things up dramatically on GPUs and sometimes on CPU as well. @@ -240,17 +241,17 @@ class DecentralizedRegistration: When not None the parwise discplament matrix is computed in a small time horizon. In short only pair of bins close in time. So the pariwaise matrix is super sparse and have values only the diagonal. - convergence_method='lsmr', 'lsqr_robust', 'gradient_descent' + convergence_method: "lsmr" | "lsqr_robust" | "gradient_descent", default: "lsqr_robust" Which method to use to compute the global displacement vector from the pairwise matrix. robust_regression_sigma: float - Use for convergence_method='lsqr_robust' for iterative selection of the regression. - temporal_prior : bool=True + Use for convergence_method="lsqr_robust" for iterative selection of the regression. + temporal_prior : bool, default: True Ensures continuity across time, unless there is evidence in the recording for jumps. - spatial_prior : bool, False + spatial_prior : bool, default: False Ensures continuity across space. Not usually necessary except in recordings with glitches across space. - force_spatial_median_continuity: bool, False - When spatial_prior=False we can optionaly apply a median continuity across spatial windows. + force_spatial_median_continuity: bool, default: False + When spatial_prior=False we can optionally apply a median continuity across spatial windows. reference_displacement : string, one of: "mean", "median", "time", "mode_search" Strategy for picking what is considered displacement=0. - "mean" : the mean displacement is subtracted @@ -258,7 +259,7 @@ class DecentralizedRegistration: - "time" : the displacement at a given time (in seconds) is subtracted - "mode_search" : an attempt is made to guess the mode. needs work. lsqr_robust_n_iter: int - Number of iteration for convergence_method='lsqr_robust'. + Number of iteration for convergence_method="lsqr_robust". """ @classmethod @@ -467,17 +468,17 @@ class IterativeTemplateRegistration: name = "iterative_template" params_doc = """ - num_amp_bins: int - number ob bins in the histogram on the log amplitues dimension, by default 20. - num_shifts_global: int - Number of spatial bin shifts to consider for global alignment, by default 15 - num_iterations: int - Number of iterations for global alignment procedure, by default 10 - num_shifts_block: int - Number of spatial bin shifts to consider for non-rigid alignment, by default 5 - smoothing_sigma: float - Sigma of gaussian for covariance matrices smoothing, by default 0.5 - kriging_sigma: float + num_amp_bins: int, default: 20 + number ob bins in the histogram on the log amplitues dimension + num_shifts_global: int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations: int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block: int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma: float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma: float, sigma parameter for kriging_kernel function kriging_p: foat p parameter for kriging_kernel function @@ -663,19 +664,19 @@ def make_2d_motion_histogram( The peaks array peak_locations : np.array Array with peak locations - weight_with_amplitude : bool, optional - If True, motion histogram is weighted by amplitudes, by default False - direction : str, optional - 'x', 'y', 'z', by default 'y' - bin_duration_s : float, optional - The temporal bin duration in s, by default 1. - bin_um : float, optional - The spatial bin size in um, by default 2. Ignored if spatial_bin_edges is given. - margin_um : float, optional - The margin to add to the minimum and maximum positions before spatial binning, by default 50. + weight_with_amplitude : bool, default: False + If True, motion histogram is weighted by amplitudes + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_duration_s : float, default: 1.0 + The temporal bin duration in s + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. Ignored if spatial_bin_edges is given. - spatial_bin_edges : np.array, optional - The pre-computed spatial bin edges, by default None + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges Returns ------- @@ -739,19 +740,19 @@ def make_3d_motion_histograms( The peaks array peak_locations : np.array Array with peak locations - direction : str, optional - 'x', 'y', 'z', by default 'y' - bin_duration_s : float, optional - The temporal bin duration in s, by default 1. - bin_um : float, optional - The spatial bin size in um, by default 2. Ignored if spatial_bin_edges is given. - margin_um : float, optional - The margin to add to the minimum and maximum positions before spatial binning, by default 50. + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_duration_s : float, default: 1.0 + The temporal bin duration in s. + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. Ignored if spatial_bin_edges is given. - log_transform : bool, optional - If True, histograms are log-transformed, by default True - spatial_bin_edges : np.array, optional - The pre-computed spatial bin edges, by default None + log_transform : bool, default: True + If True, histograms are log-transformed + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges Returns ------- @@ -1198,22 +1199,22 @@ def iterative_template_registration( spikecounts_hist_images : np.ndarray Spike count histogram images (num_temporal_bins, num_spatial_bins, num_amps_bins) - non_rigid_windows : list, optional + non_rigid_windows : list, default: None If num_non_rigid_windows > 1, this argument is required and it is a list of - windows to taper spatial bins in different blocks, by default None - num_shifts_global : int, optional - Number of spatial bin shifts to consider for global alignment, by default 15 - num_iterations : int, optional - Number of iterations for global alignment procedure, by default 10 - num_shifts_block : int, optional - Number of spatial bin shifts to consider for non-rigid alignment, by default 5 - smoothing_sigma : float, optional - Sigma of gaussian for covariance matrices smoothing, by default 0.5 - kriging_sogma : float, optional + windows to taper spatial bins in different blocks + num_shifts_global : int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations : int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block : int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma : float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma : float, default: 1 sigma parameter for kriging_kernel function - kriging_p : float, optional + kriging_p : float, default: 2 p parameter for kriging_kernel function - kriging_d : float, optional + kriging_d : float, default: 2 d parameter for kriging_kernel function Returns @@ -1350,8 +1351,8 @@ def normxcorr1d( the above formula is made to the weighted case -- and all of the normalizations are done per block in the same way. - Arguments - --------- + Parameters + ---------- template : tensor, shape (num_templates, length) The reference template signal x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) @@ -1362,7 +1363,7 @@ def normxcorr1d( If true, means will be subtracted (per weighted patch). normalized : bool If true, normalize by the variance (per weighted patch). - padding : int, optional + padding : str How far to look? if unset, we'll use half the length conv_engine : string, one of "torch", "numpy" What library to use for computing cross-correlations. diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b4a44105e4..a81212897c 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -103,9 +103,10 @@ def interpolate_motion_on_traces( Dimension of shift in channel_locations. channel_inds: None or list If not None, interpolate only a subset of channels. - spatial_interpolation_method: str in ('idw', 'kriging') - * idw : Inverse Distance Weighing - * kriging : kilosort2.5 like + spatial_interpolation_method: "idw" | "kriging", default: "kriging" + The spatial interpolation method used to interpolate the channel locations: + * idw : Inverse Distance Weighing + * kriging : kilosort2.5 like spatial_interpolation_kwargs: * specific option for the interpolation method @@ -225,28 +226,28 @@ class InterpolateMotionRecording(BasePreprocessor): Temporal bins in second. spatial_bins: None or np.array Bins for non-rigid motion. If None, rigid motion is used - direction: int (0, 1, 2) - Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z), by default 1 - spatial_interpolation_method: str - 'kriging' or 'idw' or 'nearest'. + direction: 0 | 1 | 2, default: 1 + Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z) + spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" + The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. Choice of the method: - * 'kriging' : the same one used in kilosort - * 'idw' : inverse distance weighted - * 'nearest' : use nereast channel - sigma_um: float (default 20.) - Used in the 'kriging' formula - p: int (default 1) - Used in the 'kriging' formula - num_closest: int (default 3) - Number of closest channels used by 'idw' method for interpolation. - border_mode: str + * "kriging" : the same one used in kilosort + * "idw" : inverse distance weighted + * "nearest" : use neareast channel + sigma_um: float, default: 20.0 + Used in the "kriging" formula + p: int, default: 1 + Used in the "kriging" formula + num_closest: int, default: 3 + Number of closest channels used by "idw" method for interpolation. + border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: - * 'remove_channels': remove channels on the border, the recording has less channels - * 'force_extrapolate': keep all channel and force extrapolation (can lead to strange signal) - * 'force_zeros': keep all channel but set zeros when outside (force_extrapolate=False) + * "remove_channels": remove channels on the border, the recording has less channels + * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) + * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) Returns ------- diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index bc52ea2c70..ec790e614a 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -53,8 +53,8 @@ def detect_peaks( ): """Peak detection based on threshold crossing in term of k x MAD. - In 'by_channel' : peak are detected in each channel independently - In 'locally_exclusive' : a single best peak is taken from a set of neighboring channels + In "by_channel" : peak are detected in each channel independently + In "locally_exclusive" : a single best peak is taken from a set of neighboring channels Parameters ---------- @@ -152,16 +152,17 @@ def __init__( Parameters ---------- recording : BaseRecording - The recording to process. + The recording to process peak_detector_node : PeakDetector - The peak detector node to use. + The peak detector node to use waveform_extraction_node : WaveformsNode - The waveform extraction node to use. + The waveform extraction node to use waveform_denoising_node - The waveform denoising node to use. - num_iterations : int, optional, default=2 - The number of iterations to run the algorithm. - return_output : bool, optional, default=True + The waveform denoising node to use + num_iterations : int, default: 2 + The number of iterations to run the algorithm + return_output : bool, default: True + Whether to return the output of the algorithm """ PeakDetector.__init__(self, recording, return_output=return_output) self.peak_detector_node = peak_detector_node @@ -356,26 +357,27 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): class DetectPeakByChannel(PeakDetectorWrapper): - """Detect peaks using the 'by channel' method.""" + """Detect peaks using the "by channel" method.""" name = "by_channel" engine = "numpy" preferred_mp_context = None params_doc = """ - peak_sign: 'neg', 'pos', 'both' - Sign of the peak. - detect_threshold: float - Threshold, in median absolute deviations (MAD), to use to detect peaks. - exclude_sweep_ms: float or None + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the peak + detect_threshold: float, default: 5 + Threshold, in median absolute deviations (MAD), to use to detect peaks + exclude_sweep_ms: float, default: 0.1 Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak. - noise_levels: array, optional - Estimated noise levels to use, if already computed. - If not provide then it is estimated from a random snippet of the data. - random_chunk_kwargs: dict, optional + and no larger peaks are located during the 0.1ms preceding and following the peak + noise_levels: array or None, default: None + Estimated noise levels to use, if already computed + If not provide then it is estimated from a random snippet of the data + random_chunk_kwargs: dict, default: dict() A dict that contain option to randomize chunk for get_noise_levels(). - Only used if noise_levels is None.""" + Only used if noise_levels is None + """ @classmethod def check_params( @@ -437,30 +439,31 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size): class DetectPeakByChannelTorch(PeakDetectorWrapper): - """Detect peaks using the 'by channel' method with pytorch.""" + """Detect peaks using the "by channel" method with pytorch.""" name = "by_channel_torch" engine = "torch" preferred_mp_context = "spawn" params_doc = """ - peak_sign: 'neg', 'pos', 'both' - Sign of the peak. - detect_threshold: float - Threshold, in median absolute deviations (MAD), to use to detect peaks. - exclude_sweep_ms: float or None + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the peak + detect_threshold: float, default: 5 + Threshold, in median absolute deviations (MAD), to use to detect peaks + exclude_sweep_ms: float, default: 0.1 Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak. - noise_levels: array, optional + and no larger peaks are located during the 0.1ms preceding and following the peak + noise_levels: array or None, default: None Estimated noise levels to use, if already computed. - If not provide then it is estimated from a random snippet of the data. - device : str, optional - "cpu", "cuda", or None. If None and cuda is available, "cuda" is selected, by default None - return_tensor : bool, optional - If True, the output is returned as a tensor, otherwise as a numpy array, by default False - random_chunk_kwargs: dict, optional + If not provide then it is estimated from a random snippet of the data + device : str or None, default: None + "cpu", "cuda", or None. If None and cuda is available, "cuda" is selected + return_tensor : bool, default: False + If True, the output is returned as a tensor, otherwise as a numpy array + random_chunk_kwargs: dict, default: dict() A dict that contain option to randomize chunk for get_noise_levels(). - Only used if noise_levels is None.""" + Only used if noise_levels is None. + """ @classmethod def check_params( @@ -502,7 +505,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, devi class DetectPeakLocallyExclusive(PeakDetectorWrapper): - """Detect peaks using the 'locally exclusive' method.""" + """Detect peaks using the "locally exclusive" method.""" name = "locally_exclusive" engine = "numba" @@ -578,7 +581,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, neig class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper): - """Detect peaks using the 'locally exclusive' method with pytorch.""" + """Detect peaks using the "locally exclusive" method with pytorch.""" name = "locally_exclusive_torch" engine = "torch" @@ -711,18 +714,18 @@ def _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size=5, Chunk of traces abs_thresholds : np.array Absolute thresholds by channel - peak_sign : str, optional - "neg", "pos" or "both", by default "neg" - exclude_sweep_size : int, optional - How many temporal neighbors to compare with during argrelmin, by default 5 + peak_sign : "neg" | "pos" | "both", default: "neg" + The sign of the peak to detect peaks + exclude_sweep_size : int, default: 5 + How many temporal neighbors to compare with during argrelmin Called `order` in original the implementation. The `max_window` parameter, used for deduplication, is now set as 2* exclude_sweep_size - neighbor_mask : np.array, optional + neighbor_mask : np.array or None, default: None If given, a matrix with shape (num_channels, num_neighbours) with neighbour indices for each channel. The matrix needs to be rectangular and - padded to num_channels, by default None - device : str, optional - "cpu", "cuda", or None. If None and cuda is available, "cuda" is selected, by default None + padded to num_channels + device : str or None, default: None + "cpu", "cuda", or None. If None and cuda is available, "cuda" is selected Returns ------- diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 6495503b43..75c8f7f03f 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -102,7 +102,7 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ ------- peak_locations: ndarray Array with estimated location for each spike. - The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). + The dtype depends on the method. ("x", "y") or ("x", "y", "z", "alpha"). """ peak_retriever = PeakRetriever(recording, peaks) peak_locations = _run_localization_from_peak_source( @@ -165,8 +165,8 @@ class LocalizeCenterOfMass(LocalizeBase): params_doc = """ radius_um: float Radius in um for channel sparsity. - feature: str ['ptp', 'mean', 'energy', 'peak_voltage'] - Feature to consider for computation. Default is 'ptp' + feature: "ptp" | "mean" | "energy" | "peak_voltage", default: "ptp" + Feature to consider for computation """ def __init__(self, recording, return_output=True, parents=["extract_waveforms"], radius_um=75.0, feature="ptp"): @@ -227,12 +227,12 @@ class LocalizeMonopolarTriangulation(PipelineNode): For channel sparsity. max_distance_um: float, default: 1000 Boundary for distance estimation. - enforce_decrease : bool (default True) + enforce_decrease : bool, default: True Enforce spatial decreasingness for PTP vectors - feature: string in ['ptp', 'energy', 'peak_voltage'] + feature: "ptp", "energy", "peak_voltage", default: "ptp" The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), - energy ('energy', as L2 norm) or voltages at the center of the waveform + energy ("energy", as L2 norm) or voltages at the center of the waveform (peak_voltage) """ @@ -326,11 +326,11 @@ class LocalizeGridConvolution(PipelineNode): The margin for the grid of fake templates prototype: np.array Fake waveforms for the templates. If None, generated as Gaussian - percentile: float (default 10) + percentile: float, default: 5 The percentage in [0, 100] of the best scalar products kept to estimate the position - sparsity_threshold: float (default 0.1) - The sparsity threshold (in 0-1) below which weights should be considered as 0. + sparsity_threshold: float, default: 0.01 + The sparsity threshold (in [0-1]) below which weights should be considered as 0 """ def __init__( diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index f051b1be46..823da2d928 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -11,15 +11,15 @@ def select_peaks(peaks, method="uniform", seed=None, return_indices=False, **met Parameters ---------- peaks: the peaks that have been found - method: 'uniform', 'uniform_locations', 'smart_sampling_amplitudes', 'smart_sampling_locations', - 'smart_sampling_locations_and_time' + method: "uniform", "uniform_locations", "smart_sampling_amplitudes", "smart_sampling_locations", + "smart_sampling_locations_and_time" Method to use. Options: - * 'uniform': a random subset is selected from all the peaks, on a per channel basis by default - * 'smart_sampling_amplitudes': peaks are selected via monte-carlo rejection probabilities + * "uniform": a random subset is selected from all the peaks, on a per channel basis by default + * "smart_sampling_amplitudes": peaks are selected via monte-carlo rejection probabilities based on peak amplitudes, on a per channel basis - * 'smart_sampling_locations': peaks are selection via monte-carlo rejections probabilities + * "smart_sampling_locations": peaks are selection via monte-carlo rejections probabilities based on peak locations, on a per area region basis- - * 'smart_sampling_locations_and_time': peaks are selection via monte-carlo rejections probabilities + * "smart_sampling_locations_and_time": peaks are selection via monte-carlo rejections probabilities based on peak locations and time positions, assuming everything is independent seed: int @@ -29,26 +29,26 @@ def select_peaks(peaks, method="uniform", seed=None, return_indices=False, **met method_kwargs: dict of kwargs method Keyword arguments for the chosen method: - 'uniform': - * select_per_channel: bool - If True, the selection is done on a per channel basis (False by default) + "uniform": + * select_per_channel: bool, default: False + If True, the selection is done on a per channel basis * n_peaks: int If select_per_channel is True, this is the number of peaks per channels, otherwise this is the total number of peaks - 'smart_sampling_amplitudes': + "smart_sampling_amplitudes": * noise_levels : array The noise levels used while detecting the peaks * n_peaks: int If select_per_channel is True, this is the number of peaks per channels, otherwise this is the total number of peaks - * select_per_channel: bool - If True, the selection is done on a per channel basis (False by default) - 'smart_sampling_locations': + * select_per_channel: bool, default: False + If True, the selection is done on a per channel basis + "smart_sampling_locations": * n_peaks: int Total number of peaks to select * peaks_locations: array The locations of all the peaks, computed via localize_peaks - 'smart_sampling_locations_and_time': + "smart_sampling_locations_and_time": * n_peaks: int Total number of peaks to select * peaks_locations: array diff --git a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py index df6dd81a97..e70f44442c 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py @@ -15,15 +15,15 @@ class SavGolDenoiser(WaveformsNode): Parameters ---------- recording: BaseRecording - The recording extractor object. - return_output: bool, optional - Whether to return output from this node (default True). - parents: list of PipelineNodes, optional - The parent nodes of this node (default None). - order: int, optional - the order of the filter (default 3) - window_length_ms: float, optional - the temporal duration of the filter in ms (default 0.25) + The recording extractor object + return_output: bool, default: True + Whether to return output from this node + parents: list of PipelineNodes, default: None + The parent nodes of this node + order: int, default: 3 + the order of the filter + window_length_ms: float, default: 0.25 + the temporal duration of the filter in ms """ def __init__( diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 89f4efa924..a9e6126ccc 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -103,23 +103,24 @@ def fit( Parameters ---------- recording : BaseRecording - The recording object. + The recording object n_components : int - The number of components to use for the PCA model. + The number of components to use for the PCA model model_folder_path : str, Path - The path to the folder containing the pca model and the training metadata. + The path to the folder containing the pca model and the training metadata detect_peaks_params : dict - The parameters for peak detection. + The parameters for peak detection peak_selection_params : dict - The parameters for peak selection. - whiten : bool, optional - Whether to whiten the data, by default True. - radius_um : float, optional - The radius (in micrometers) to use for definint sparsity, by default None. - ms_before : float, optional - The number of milliseconds to include before the peak of the spike, by default 1. - ms_after : float, optional - The number of milliseconds to include after the peak of the spike, by default 1. + The parameters for peak selection + ms_before : float, default: 1 + The number of milliseconds to include before the peak of the spike + ms_after : float, default: 1 + The number of milliseconds to include after the peak of the spike + whiten : bool, default: True + Whether to whiten the data + radius_um : float or None, default: None + The radius (in micrometers) to use for definint sparsity. If None, no sparsity is used + {} @@ -186,12 +187,12 @@ class TemporalPCAProjection(TemporalPCBaseNode): Parameters ---------- recording : BaseRecording - The recording object. + The recording object parents: list - The parent nodes of this node. This should contain a mechanism to extract waveforms. + The parent nodes of this node. This should contain a mechanism to extract waveforms model_folder_path : str, Path - The path to the folder containing the pca model and the training metadata. - return_output: bool, optional, true by default + The path to the folder containing the pca model and the training metadata + return_output: bool, default: True use false to suppress the output of this node in the pipeline """ @@ -242,12 +243,12 @@ class TemporalPCADenoising(TemporalPCBaseNode): Parameters ---------- recording : BaseRecording - The recording object. + The recording object parents: list - The parent nodes of this node. This should contain a mechanism to extract waveforms. + The parent nodes of this node. This should contain a mechanism to extract waveforms model_folder_path : str, Path - The path to the folder containing the pca model and the training metadata. - return_output: bool, optional, true by default + The path to the folder containing the pca model and the training metadata + return_output: bool, default: True use false to suppress the output of this node in the pipeline """ diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index edac5c1932..a1e532eeb7 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -16,27 +16,27 @@ class WaveformThresholder(WaveformsNode): This node allows you to perform adaptive masking by setting channels to 0 that have a given feature below a certain threshold. The available features - to consider are peak-to-peak amplitude ('ptp'), mean amplitude ('mean'), - energy ('energy'), and peak voltage ('peak_voltage'). + to consider are peak-to-peak amplitude ("ptp"), mean amplitude ("mean"), + energy ("energy"), and peak voltage ("peak_voltage"). Parameters ---------- recording: BaseRecording - The recording extractor object. - return_output: bool, optional - Whether to return output from this node (default True). - parents: list of PipelineNodes, optional - The parent nodes of this node (default None). - feature: {'ptp', 'mean', 'energy', 'peak_voltage'}, optional - The feature to be considered for thresholding (default 'ptp'). Features are normalized with the channel noise levels. - threshold: float, optional - The threshold value for the selected feature (default 2). - noise_levels: array, optional + The recording extractor object + return_output: bool, default: True + Whether to return output from this node + parents: list of PipelineNodes, default: None + The parent nodes of this node + feature: "ptp" | "mean" | "energy" | "peak_voltage", default: "ptp" + The feature to be considered for thresholding . Features are normalized with the channel noise levels. + threshold: float, default: 2 + The threshold value for the selected feature + noise_levels: array of None, default: None The noise levels to determine the thresholds - random_chunk_kwargs: dict + random_chunk_kwargs: dict, default: dict() Parameters for computing noise levels, if not provided (sub optimal) - operator: callable, optional - Comparator to flag values that should be set to 0 (default less or equal) + operator: callable, default: operator.le (less or equal) + Comparator to flag values that should be set to 0 """ def __init__( diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 6b6496a577..a917f043b3 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -15,24 +15,23 @@ class AmplitudesWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The input waveform extractor - unit_ids : list - List of unit ids, default None - segment_index : int - The segment index (or None if mono-segment), default None - max_spikes_per_unit : int - Number of max spikes per unit to display. Use None for all spikes. - Default None. - hide_unit_selector : bool - If True the unit selector is not displayed, default False + unit_ids : list or None, default: None + List of unit ids + segment_index : int or None, default: None + The segment index (or None if mono-segment) + max_spikes_per_unit : int or None, default: None + Number of max spikes per unit to display. Use None for all spikes + hide_unit_selector : bool, default: False + If True the unit selector is not displayed (sortingview backend) - plot_histogram : bool - If True, an histogram of the amplitudes is plotted on the right axis, default False + plot_histogram : bool, default: False + If True, an histogram of the amplitudes is plotted on the right axis (matplotlib backend) - bins : int + bins : int or None, default: None If plot_histogram is True, the number of bins for the amplitude histogram. - If None this is automatically adjusted, default None - plot_legend : bool - True includes legend in plot, default True + If None this is automatically adjusted + plot_legend : bool, default: True + True includes legend in plot """ def __init__( diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 9fc7b73707..a5d3cb2429 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -6,7 +6,7 @@ def get_default_plotter_backend(): """Return the default backend for spikeinterface widgets. - The default backend is 'matplotlib' at init. + The default backend is "matplotlib" at init. It can be be globally set with `set_default_plotter_backend(backend)` """ @@ -21,23 +21,23 @@ def set_default_plotter_backend(backend): backend_kwargs_desc = { "matplotlib": { - "figure": "Matplotlib figure. When None, it is created. Default None", - "ax": "Single matplotlib axis. When None, it is created. Default None", - "axes": "Multiple matplotlib axes. When None, they is created. Default None", - "ncols": "Number of columns to create in subplots. Default 5", - "figsize": "Size of matplotlib figure. Default None", - "figtitle": "The figure title. Default None", + "figure": "Matplotlib figure. When None, it is created, default: None", + "ax": "Single matplotlib axis. When None, it is created, default: None", + "axes": "Multiple matplotlib axes. When None, they is created, default: None", + "ncols": "Number of columns to create in subplots, default: 5", + "figsize": "Size of matplotlib figure, default: None", + "figtitle": "The figure title, default: None", }, "sortingview": { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", + "generate_url": "If True, the figurl URL is generated and printed, default: True", + "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell, default: True.", + "figlabel": "The figurl figure label, default: None", + "height": "The height of the sortingview View in jupyter, default: None", }, "ipywidgets": { - "width_cm": "Width of the figure in cm (default 10)", - "height_cm": "Height of the figure in cm (default 6)", - "display": "If True, widgets are immediately displayed", + "width_cm": "Width of the figure in cm, default: 10", + "height_cm": "Height of the figure in cm, default 6", + "display": "If True, widgets are immediately displayed, default: True", # "controllers": "" }, "ephyviewer": {}, @@ -123,7 +123,7 @@ def __init__(self, d): Helper function that transform a dict into an object where attributes are the keys of the dict - d = {'a': 1, 'b': 'yep'} + d = {"a": 1, "b": "yep"} o = to_attr(d) print(o.a, o.b) """ diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index 2b86a2af2d..046146635c 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -13,13 +13,13 @@ class ComparisonCollisionBySimilarityWidget(BaseWidget): The collision ground truth comparison object templates: array template of units - mode: 'heatmap' or 'lines' - to see collision curves for every pairs ('heatmap') or as lines averaged over pairs. + mode: "heatmap" or "lines" + to see collision curves for every pairs ("heatmap") or as lines averaged over pairs. similarity_bins: array - if mode is 'lines', the bins used to average the pairs + if mode is "lines", the bins used to average the pairs cmap: string - colormap used to show averages if mode is 'lines' - metric: 'cosine_similarity' + colormap used to show averages if mode is "lines" + metric: "cosine_similarity" metric for ordering good_only: True keep only the pairs with a non zero accuracy (found templates) @@ -182,12 +182,12 @@ class StudyComparisonCollisionBySimilarityWidget(BaseWidget): The collision study object. case_keys: list or None A selection of cases to plot, if None, then all. - metric: 'cosine_similarity' + metric: "cosine_similarity" metric for ordering similarity_bins: array - if mode is 'lines', the bins used to average the pairs + if mode is "lines", the bins used to average the pairs cmap: string - colormap used to show averages if mode is 'lines' + colormap used to show averages if mode is "lines" good_only: False keep only the pairs with a non zero accuracy (found templates) min_accuracy: float diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 3ec3fa11b6..9403a5dd03 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -15,16 +15,16 @@ class CrossCorrelogramsWidget(BaseWidget): ---------- waveform_or_sorting_extractor : WaveformExtractor or BaseSorting The object to compute/get crosscorrelograms from - unit_ids list - List of unit ids, default None - window_ms : float - Window for CCGs in ms, default 100.0 ms - bin_ms : float - Bin size in ms, default 1.0 ms - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False - unit_colors: dict or None - If given, a dictionary with unit ids as keys and colors as values, default None + unit_ids list or None, default: None + List of unit ids + window_ms : float, default: 100.0 + Window for CCGs in ms + bin_ms : float, default: 1.0 + Bin size in ms + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed + unit_colors: dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values """ def __init__( diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 6e4433ee60..5e934f9702 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -136,10 +136,10 @@ class StudyPerformances(BaseWidget): mode: "ordered" | "snr" | "swarm", default: "ordered" Which plot mode to use: - * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy (default) + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy * "snr": plot performance metrics vs snr * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) - performance_names: list or tuple, default: ('accuracy', 'precision', 'recall') + performance_names: list or tuple, default: ("accuracy", "precision", "recall") Which performances to plot ("accuracy", "precision", "recall") case_keys: list or None A selection of cases to plot, if None, then all. diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index bc44e58a33..3a5c8437cd 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -16,18 +16,18 @@ class MetricsBaseWidget(BaseWidget): Data frame with metrics sorting: BaseSorting The sorting object used for metrics calculations - unit_ids: list - List of unit ids, default None - skip_metrics: list or None - If given, a list of quality metrics to skip, default None - include_metrics: list or None - If given, a list of quality metrics to include, default None - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False - include_metrics_data : bool - If True, metrics data are included in unit table, by default True + unit_ids: list or None, default: None + List of unit ids, default: None + skip_metrics: list or None, default: None + If given, a list of quality metrics to skip, default: None + include_metrics: list or None, default: None + If given, a list of quality metrics to include, default: None + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed + include_metrics_data : bool, default: True + If True, metrics data are included in unit table """ def __init__( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index cb11bcce0c..b097dca1f0 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -11,24 +11,24 @@ class MotionWidget(BaseWidget): ---------- motion_info: dict The motion info return by correct_motion() or load back with load_motion_info() - recording : RecordingExtractor, optional - The recording extractor object (only used to get "real" times), default None - sampling_frequency : float, optional - The sampling frequency (needed if recording is None), default None - depth_lim : tuple - The min and max depth to display, default None (min and max of the recording) - motion_lim : tuple - The min and max motion to display, default None (min and max of the motion) - color_amplitude : bool - If True, the color of the scatter points is the amplitude of the peaks, default False - scatter_decimate : int - If > 1, the scatter points are decimated, default None - amplitude_cmap : str - The colormap to use for the amplitude, default 'inferno' - amplitude_clim : tuple - The min and max amplitude to display, default None (min and max of the amplitudes) - amplitude_alpha : float - The alpha of the scatter points, default 0.5 + recording : RecordingExtractor, default: None + The recording extractor object (only used to get "real" times) + sampling_frequency : float, default: None + The sampling frequency (needed if recording is None) + depth_lim : tuple or None, default: None + The min and max depth to display, if None (min and max of the recording) + motion_lim : tuple or None, default: None + The min and max motion to display, if None (min and max of the motion) + color_amplitude : bool, default: False + If True, the color of the scatter points is the amplitude of the peaks + scatter_decimate : int, default: None + If > 1, the scatter points are decimated + amplitude_cmap : str, default: "inferno" + The colormap to use for the amplitude + amplitude_clim : tuple or None, default: None + The min and max amplitude to display, if None (min and max of the amplitudes) + amplitude_alpha : float, default: 1 + The alpha of the scatter points """ def __init__( diff --git a/src/spikeinterface/widgets/multicomparison.py b/src/spikeinterface/widgets/multicomparison.py index e01a79dfd5..fb34156fef 100644 --- a/src/spikeinterface/widgets/multicomparison.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -13,15 +13,15 @@ class MultiCompGraphWidget(BaseWidget): ---------- multi_comparison: BaseMultiComparison The multi comparison object - draw_labels: bool + draw_labels: bool, default: False If True unit labels are shown - node_cmap: matplotlib colormap - The colormap to be used for the nodes (default 'viridis') - edge_cmap: matplotlib colormap - The colormap to be used for the edges (default 'hot') - alpha_edges: float + node_cmap: matplotlib colormap, default: "viridis" + The colormap to be used for the nodes + edge_cmap: matplotlib colormap, default: "hot" + The colormap to be used for the edges + alpha_edges: float, default: 0.5 Alpha value for edges - colorbar: bool + colorbar: bool, default: False If True a colorbar for the edges is plotted """ @@ -119,9 +119,9 @@ class MultiCompGlobalAgreementWidget(BaseWidget): ---------- multi_comparison: BaseMultiComparison The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap, default: 'YlOrRd' + plot_type: "pie" | "bar", default: "pie" + The plot type + cmap: matplotlib colormap, default: "YlOrRd" The colormap to be used for the nodes fontsize: int, default: 9 The text fontsize @@ -197,15 +197,14 @@ class MultiCompAgreementBySorterWidget(BaseWidget): ---------- multi_comparison: BaseMultiComparison The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap - The colormap to be used for the nodes (default 'Reds') - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax - and figure parameters are ignored. + plot_type: "pie" | "bar", default: "pie + The plot type + cmap: matplotlib colormap, default: "Reds" + The colormap to be used for the nodes + fontsize: int, default: 9 + The text fontsize show_legend: bool - Show the legend in the last axes (default True). + Show the legend in the last axes Returns ------- diff --git a/src/spikeinterface/widgets/peak_activity.py b/src/spikeinterface/widgets/peak_activity.py index 8501d7ef7d..f2b3562aff 100644 --- a/src/spikeinterface/widgets/peak_activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -21,16 +21,16 @@ class PeakActivityMapWidget(BaseWidget): peaks: None or numpy array Optionally can give already detected peaks to avoid multiple computation. - detect_peaks_kwargs: None or dict + detect_peaks_kwargs: None or dict, default: None If peaks is None here the kwargs for detect_peak function. - bin_duration_s: None or float + bin_duration_s: None or float, default: None If None then static image If not None then it is an animation per bin. - with_contact_color: bool (default True) + with_contact_color: bool, default: True Plot rates with contact colors - with_interpolated_map: bool (default True) + with_interpolated_map: bool, default: True Plot rates with interpolated map - with_channel_ids: bool False default + with_channel_ids: bool, default: False Add channel ids text on the probe diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 4a6b46b72d..4406557d93 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -10,16 +10,16 @@ class QualityMetricsWidget(MetricsBaseWidget): ---------- waveform_extractor : WaveformExtractor The object to compute/get quality metrics from - unit_ids: list - List of unit ids, default None - include_metrics: list - If given, a list of quality metrics to include, default None - skip_metrics: list or None - If given, a list of quality metrics to skip, default None - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False + unit_ids: list or None, default: None + List of unit ids + include_metrics: list or None, default: None + If given, a list of quality metrics to include + skip_metrics: list or None, default: None + If given, a list of quality metrics to skip + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed """ def __init__( diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index b9760205f9..9b98c1adaa 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -19,18 +19,24 @@ class SortingSummaryWidget(BaseWidget): Parameters ---------- waveform_extractor : WaveformExtractor - The waveform extractor object. - sparsity : ChannelSparsity or None - Optional ChannelSparsity to apply, default None + The waveform extractor object + unit_ids : list or None, default: None + List of unit ids + sparsity : ChannelSparsity or None, default: None + Optional ChannelSparsity to apply If WaveformExtractor is already sparse, the argument is ignored - max_amplitudes_per_unit : int or None - Maximum number of spikes per unit for plotting amplitudes, - by default None (all spikes) - curation : bool - If True, manual curation is enabled, by default False + max_amplitudes_per_unit : int or None, default: None + Maximum number of spikes per unit for plotting amplitudes. + If None, all spikes are plotted + curation : bool, default: False + If True, manual curation is enabled (sortingview backend) - unit_table_properties : list or None - List of properties to be added to the unit table, by default None + unit_table_properties : list or None, default: None + List of properties to be added to the unit table + label_choices : list or None, default: None + List of labels to be added to the curation table + unit_table_properties : list or None, default: None + List of properties to be added to the unit table (sortingview backend) """ diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index fda2356105..6ab0962f99 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -13,26 +13,25 @@ class SpikeLocationsWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The object to compute/get spike locations from - unit_ids : list - List of unit ids, default None - segment_index : int or None - The segment index (or None if mono-segment), default None - max_spikes_per_unit : int + unit_ids : list or None, default: None + List of unit ids + segment_index : int or None, default: None + The segment index (or None if mono-segment) + max_spikes_per_unit : int or None, default: 500 Number of max spikes per unit to display. Use None for all spikes. - Default 500. - with_channel_ids : bool - Add channel ids text on the probe, default False - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False - plot_all_units : bool + with_channel_ids : bool, default: False + Add channel ids text on the probe + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed + plot_all_units : bool, default: True If True, all units are plotted. The unselected ones (not in unit_ids), - are plotted in grey. Default True (matplotlib backend) - plot_legend : bool - If True, the legend is plotted. Default False (matplotlib backend) - hide_axis : bool - If True, the axis is set to off. Default False (matplotlib backend) + are plotted in grey (matplotlib backend) + plot_legend : bool, default: False + If True, the legend is plotted (matplotlib backend) + hide_axis : bool, default: False + If True, the axis is set to off (matplotlib backend) """ # possible_backends = {} diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index b68efc3f8a..b6946542b7 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -19,44 +19,43 @@ class SpikesOnTracesWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The waveform extractor - channel_ids : list - The channel ids to display, default None - unit_ids : list - List of unit ids, default None - order_channel_by_depth : bool - If true orders channel by depth, default False - time_range: list - List with start time and end time, default None - sparsity : ChannelSparsity or None - Optional ChannelSparsity to apply. - If WaveformExtractor is already sparse, the argument is ignored, default None - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None + channel_ids : list or None, default: None + The channel ids to display + unit_ids : list or None, default: None + List of unit ids + order_channel_by_depth : bool, default: False + If true orders channel by depth + time_range: list or None, default: None + List with start time and end time in seconds + sparsity : ChannelSparsity or None, default: None + Optional ChannelSparsity to apply + If WaveformExtractor is already sparse, the argument is ignored + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values If None, then the get_unit_colors() is internally used. (matplotlib backend) - mode : str in ('line', 'map', 'auto') default: 'auto' - * 'line': classical for low channel count - * 'map': for high channel count use color heat map - * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) - return_scaled : bool - If True and the recording has scaled traces, it plots the scaled traces, default False - cmap : str - matplotlib colormap used in mode 'map', default 'RdBu' - show_channel_ids : bool - Set yticks with channel ids, default False - color_groups : bool - If True groups are plotted with different colors, default False - color : str - The color used to draw the traces, default None - clim : None, tuple or dict - When mode is 'map', this argument controls color limits. + mode : "line" | "map" | "auto", default: "auto" + * "line": classical for low channel count + * "map": for high channel count use color heat map + * "auto": auto switch depending on the channel count ("line" if less than 64 channels, "map" otherwise) + return_scaled : bool, default: False + If True and the recording has scaled traces, it plots the scaled traces + cmap : str, default: "RdBu" + matplotlib colormap used in mode "map" + show_channel_ids : bool, default: False + Set yticks with channel ids + color_groups : bool, default: False + If True groups are plotted with different colors + color : str or None, default: None + The color used to draw the traces + clim : None, tuple or dict, default: None + When mode is "map", this argument controls color limits. If dict, keys should be the same as recording keys - Default None - with_colorbar : bool - When mode is 'map', a colorbar is added, by default True - tile_size : int - For sortingview backend, the size of each tile in the rendered image, default 512 - seconds_per_row : float - For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 + with_colorbar : bool, default: True + When mode is "map", a colorbar is added + tile_size : int, default: 512 + For sortingview backend, the size of each tile in the rendered image + seconds_per_row : float, default: 0.2 + For "map" mode and sortingview backend, seconds to render in each row """ def __init__( diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 748babb57d..1658efe737 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -10,16 +10,16 @@ class TemplateMetricsWidget(MetricsBaseWidget): ---------- waveform_extractor : WaveformExtractor The object to compute/get template metrics from - unit_ids : list - List of unit ids, default None - include_metrics : list - If given list of quality metrics to include, default None - skip_metrics : list or None - If given, a list of quality metrics to skip, default None - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False + unit_ids : list or None, default: None + List of unit ids + include_metrics : list or None, default: None + If given list of quality metrics to include + skip_metrics : list or None or None, default: None + If given, a list of quality metrics to skip + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed """ def __init__( diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 63ac177835..4ab469f456 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -12,17 +12,17 @@ class TemplateSimilarityWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The object to compute/get template similarity from - unit_ids : list - List of unit ids default None - display_diagonal_values : bool + unit_ids : list or None, default: None + List of unit ids default: None + display_diagonal_values : bool, default: False If False, the diagonal is displayed as zeros. - If True, the similarity values (all 1s) are displayed, default False - cmap : str - The matplotlib colormap. Default 'viridis'. - show_unit_ticks : bool - If True, ticks display unit ids, default False. - show_colorbar : bool - If True, color bar is displayed, default True. + If True, the similarity values (all 1s) are displayed + cmap : matplotlib colormap, default: "viridis" + The matplotlib colormap + show_unit_ticks : bool, default: False + If True, ticks display unit ids + show_colorbar : bool, default: True + If True, color bar is displayed """ def __init__( diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index fc8b30eb05..63fe4e8d8f 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -16,42 +16,41 @@ class TracesWidget(BaseWidget): recording: RecordingExtractor, dict, or list The recording extractor object. If dict (or list) then it is a multi-layer display to compare, for example, different processing steps - segment_index: None or int - The segment index (required for multi-segment recordings), default None - channel_ids: list - The channel ids to display, default None - order_channel_by_depth: bool - Reorder channel by depth, default False - time_range: list - List with start time and end time, default None - mode: str - Three possible modes, default 'auto': - - * 'line': classical for low channel count - * 'map': for high channel count use color heat map - * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) - return_scaled: bool - If True and the recording has scaled traces, it plots the scaled traces, default False - cmap: str - matplotlib colormap used in mode 'map', default 'RdBu' - show_channel_ids: bool - Set yticks with channel ids, default False - color_groups: bool - If True groups are plotted with different colors, default False - color: str - The color used to draw the traces, default None - clim: None, tuple or dict - When mode is 'map', this argument controls color limits. + segment_index: None or int, default: None + The segment index (required for multi-segment recordings) + channel_ids: list or None, default: None + The channel ids to display + order_channel_by_depth: bool, default: False + Reorder channel by depth + time_range: list, tuple or None, default: None + List with start time and end time + mode: "line" | "map" | "auto", default: "auto" + Three possible modes + + * "line": classical for low channel count + * "map": for high channel count use color heat map + * "auto": auto switch depending on the channel count ("line" if less than 64 channels, "map" otherwise) + return_scaled: bool, default: False + If True and the recording has scaled traces, it plots the scaled traces + cmap: matplotlib colormap, default: "RdBu_r" + matplotlib colormap used in mode "map" + show_channel_ids: bool, default: False + Set yticks with channel ids + color_groups: bool, default: False + If True groups are plotted with different colors + color: str or None, default: None + The color used to draw the traces + clim: None, tuple or dict, default: None + When mode is "map", this argument controls color limits. If dict, keys should be the same as recording keys - Default None - with_colorbar: bool - When mode is 'map', a colorbar is added, by default True - tile_size: int - For sortingview backend, the size of each tile in the rendered image, default 1500 - seconds_per_row: float - For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 - add_legend : bool - If True adds legend to figures, default True + with_colorbar: bool, default: True + When mode is "map", a colorbar is added + tile_size: int, default: 1500 + For sortingview backend, the size of each tile in the rendered image + seconds_per_row: float, default: 0.2 + For "map" mode and sortingview backend, seconds to render in each row + add_legend : bool, default: True + If True adds legend to figures, default: True """ def __init__( diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 1cc7c909a1..1e40a7940e 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -16,12 +16,12 @@ class UnitDepthsWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The input waveform extractor - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - depth_axis : int - The dimension of unit_locations that is depth, default 1 - peak_sign: str (neg/pos/both) - Sign of peak for amplitudes, default 'neg' + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + depth_axis : int, default: 1 + The dimension of unit_locations that is depth + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of peak for amplitudes """ def __init__( diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index b41ee3508b..d5f26f6dfd 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -16,21 +16,21 @@ class UnitLocationsWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The object to compute/get unit locations from - unit_ids : list - List of unit ids default None - with_channel_ids : bool - Add channel ids text on the probe, default False - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - hide_unit_selector : bool - If True, the unit selector is not displayed, default False (sortingview backend) - plot_all_units : bool + unit_ids : list or None, default: None + List of unit ids + with_channel_ids : bool, default: False + Add channel ids text on the probe + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values + hide_unit_selector : bool, default: False + If True, the unit selector is not displayed (sortingview backend) + plot_all_units : bool, default: True If True, all units are plotted. The unselected ones (not in unit_ids), - are plotted in grey, default True (matplotlib backend) - plot_legend : bool - If True, the legend is plotted, default False (matplotlib backend) - hide_axis : bool - If True, the axis is set to off, default False (matplotlib backend) + are plotted in grey (matplotlib backend) + plot_legend : bool, default: False + If True, the legend is plotted (matplotlib backend) + hide_axis : bool, default: False + If True, the axis is set to off (matplotlib backend) """ def __init__( diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py index 3d605936a2..0d7429f17d 100644 --- a/src/spikeinterface/widgets/unit_presence.py +++ b/src/spikeinterface/widgets/unit_presence.py @@ -13,19 +13,18 @@ class UnitPresenceWidget(BaseWidget): The sorting extractor object segment_index: None or int The segment index. - time_range: list + time_range: list or None, default: None List with start time and end time - bin_duration_s: float, default 0.5 - Bin size (in seconds) for the heat map time axis. - smooth_sigma: float or None - + bin_duration_s: float, default: 0.5 + Bin size (in seconds) for the heat map time axis + smooth_sigma: float, default: 4.5 + Sigma for the Gaussian kernel (in number of bins) """ def __init__( self, sorting, segment_index=None, - unit_ids=None, time_range=None, bin_duration_s=0.05, smooth_sigma=4.5, diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 4068c1c530..b8eea80ef4 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -22,9 +22,9 @@ class UnitProbeMapWidget(BaseWidget): List of unit ids. channel_ids: list The channel ids to display - animated: True/False - animation for amplitude on time - with_channel_ids: bool False default + animated: bool, default: False + Animation for amplitude on time + with_channel_ids: bool, default: False add channel ids text on the probe """ diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 964b5813e6..35fde07326 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -23,10 +23,10 @@ class UnitSummaryWidget(BaseWidget): The waveform extractor object unit_id : int or str The unit id to plot the summary of - unit_colors : dict or None - If given, a dictionary with unit ids as keys and colors as values, default None - sparsity : ChannelSparsity or None - Optional ChannelSparsity to apply, default None + unit_colors : dict or None, default: None + If given, a dictionary with unit ids as keys and colors as values, + sparsity : ChannelSparsity or None, default: None + Optional ChannelSparsity to apply. If WaveformExtractor is already sparse, the argument is ignored """ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 8ffc931bf2..71d5f1663b 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -16,47 +16,47 @@ class UnitWaveformsWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The input waveform extractor - channel_ids: list - The channel ids to display, default None - unit_ids : list - List of unit ids, default None - plot_templates : bool - If True, templates are plotted over the waveforms, default True - sparsity : ChannelSparsity or None - Optional ChannelSparsity to apply, default None + channel_ids: list or None, default: None + The channel ids to display + unit_ids : list or None, default: None + List of unit ids + plot_templates : bool, default: True + If True, templates are plotted over the waveforms + sparsity : ChannelSparsity or None, default: None + Optional ChannelSparsity to apply If WaveformExtractor is already sparse, the argument is ignored - set_title : bool - Create a plot title with the unit number if True, default True - plot_channels : bool - Plot channel locations below traces, default False - unit_selected_waveforms : None or dict + set_title : bool, default: True + Create a plot title with the unit number if True + plot_channels : bool, default: False + Plot channel locations below traces + unit_selected_waveforms : None or dict, default: None A dict key is unit_id and value is the subset of waveforms indices that should be - be displayed (matplotlib backend), default None - max_spikes_per_unit : int or None + be displayed (matplotlib backend) + max_spikes_per_unit : int or None, default: 50 If given and unit_selected_waveforms is None, only max_spikes_per_unit random units are - displayed per waveform, default 50 (matplotlib backend) - axis_equal : bool - Equal aspect ratio for x and y axis, to visualize the array geometry to scale, default False - lw_waveforms : float - Line width for the waveforms, default 1 (matplotlib backend) - lw_templates : float - Line width for the templates, default 2 (matplotlib backend) - unit_colors : None or dict - A dict key is unit_id and value is any color format handled by matplotlib, default None + displayed per waveform, (matplotlib backend) + axis_equal : bool, default: False + Equal aspect ratio for x and y axis, to visualize the array geometry to scale + lw_waveforms : float, default: 1 + Line width for the waveforms, (matplotlib backend) + lw_templates : float, default: 2 + Line width for the templates, (matplotlib backend) + unit_colors : None or dict, default: None + A dict key is unit_id and value is any color format handled by matplotlib. If None, then the get_unit_colors() is internally used. (matplotlib backend) - alpha_waveforms : float - Alpha value for waveforms, default 0.5 (matplotlib backend) - alpha_templates : float - Alpha value for templates, default 1 (matplotlib backend) - hide_unit_selector : bool - For sortingview backend, if True the unit selector is not displayed, default False - same_axis : bool - If True, waveforms and templates are displayed on the same axis, default False (matplotlib backend) - x_offset_units : bool + alpha_waveforms : float, default: 0.5 + Alpha value for waveforms (matplotlib backend) + alpha_templates : float, default: 1 + Alpha value for templates, (matplotlib backend) + hide_unit_selector : bool, default: False + For sortingview backend, if True the unit selector is not displayed + same_axis : bool, default: False + If True, waveforms and templates are displayed on the same axis (matplotlib backend) + x_offset_units : bool, default: False In case same_axis is True, this parameter allow to x-offset the waveforms for different units - (recommended for a few units), default False (matlotlib backend) - plot_legend : bool - Display legend, default True + (recommended for a few units) (matlotlib backend) + plot_legend : bool, default: True + Display legend (matplotlib backend) """ def __init__( diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index b3391c0712..c49d866139 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -14,23 +14,23 @@ class UnitWaveformDensityMapWidget(BaseWidget): ---------- waveform_extractor : WaveformExtractor The waveformextractor for calculating waveforms - channel_ids : list - The channel ids to display, default None - unit_ids : list - List of unit ids, default None - sparsity : ChannelSparsity or None - Optional ChannelSparsity to apply, default None + channel_ids : list or None, default: None + The channel ids to display + unit_ids : list or None, default: None + List of unit ids + sparsity : ChannelSparsity or None, default: None + Optional ChannelSparsity to apply If WaveformExtractor is already sparse, the argument is ignored - use_max_channel : bool - Use only the max channel, default False - peak_sign : str (neg/pos/both) - Used to detect max channel only when use_max_channel=True, default 'neg' - unit_colors : None or dict + use_max_channel : bool, default: False + Use only the max channel + peak_sign : "neg" | "pos" | "both", default: "neg" + Used to detect max channel only when use_max_channel=True + unit_colors : None or dict, default: None A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used, default None - same_axis : bool + If None, then the get_unit_colors() is internally used + same_axis : bool, default: False If True then all density are plot on the same axis and then channels is the union - all channel per units, default False + all channel per units """ def __init__( diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 03d5be0c53..e31ef7679e 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -24,19 +24,19 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB Parameters ---------- - color_engine : str 'auto' / 'matplotlib' / 'colorsys' / 'distinctipy' + color_engine : "auto" | "matplotlib" | "colorsys" | "distinctipy", default: "auto" The engine to generate colors map_name : str Used for matplotlib - format: str - The output formats, default 'RGBA' - shuffle : bool or None - Shuffle or not, default None + format: str, default: "RGBA" + The output formats + shuffle : bool or None, default: None + Shuffle or not the colors. If None then: * set to True for matplotlib and colorsys * set to False for distinctipy - seed: int or None - Set the seed, default None + seed: int or None, default: None + Set the seed Returns -------