From cf5041062a73b1c61d8f15200ade73ea1f1d8bae Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 19:14:51 +0100 Subject: [PATCH 01/44] Run checks for singularity, docker and related python module installations. --- src/spikeinterface/sorters/runsorter.py | 18 ++++++++++++++ src/spikeinterface/sorters/utils/misc.py | 31 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index baec6aaac3..44a08a34a7 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -169,6 +169,15 @@ def run_sorter( container_image = None else: container_image = docker_image + + if not has_docker(): + raise RuntimeError("Docker is not installed. Install docker " + "on this machine to run sorting with docker.") + + if not has_docker_python(): + raise RuntimeError("The python `docker` package must be installed." + "Install with `pip install docker`") + else: mode = "singularity" assert not docker_image @@ -176,6 +185,15 @@ def run_sorter( container_image = None else: container_image = singularity_image + + if not has_singularity(): + raise RuntimeError("Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity.") + + if not has_spython(): + raise RuntimeError("The python singularity package must be installed." + "Install with `pip install spython`") + return run_sorter_container( container_image=container_image, mode=mode, diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 0a6b4a986c..a1cf34f059 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import subprocess # TODO: decide best format for this from subprocess import check_output, CalledProcessError from typing import List, Union @@ -80,3 +81,33 @@ def has_nvidia(): return device_count > 0 except RuntimeError: # Failed to dlopen libcuda.so return False + +def _run_subprocess_silently(command): + output = subprocess.run( + command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return output + + +def has_docker(): + return self._run_subprocess_silently("docker --version").returncode == 0 + + +def has_singularity(): + return self._run_subprocess_silently("singularity --version").returncode == 0 + + +def has_docker_python(): + try: + import docker + return True + except ImportError: + return False + + +def has_spython(): + try: + import spython + return True + except ImportError: + return False From e49521939f2023c50943afad21a663c3d7822011 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 20:09:03 +0100 Subject: [PATCH 02/44] Add nvidia dependency checks, tidy up. 
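
The four checks added in the previous patch share two probe patterns: a silenced
`--version` subprocess call for the CLI tools (docker, singularity), and an
import attempt for the python client packages (docker, spython). A minimal
standalone sketch of the same pattern, standard library only (the helper names
here are illustrative, not the SpikeInterface API):

    import subprocess
    import importlib


    def cli_available(command: str) -> bool:
        # Ask the tool for its version; a zero return code means the
        # executable exists on PATH and responds.
        result = subprocess.run(
            f"{command} --version",
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return result.returncode == 0


    def module_available(module_name: str) -> bool:
        # Mirror of the try/except-import checks above, for any module name.
        try:
            importlib.import_module(module_name)
            return True
        except ImportError:
            return False


    # e.g. cli_available("docker"), module_available("spython")
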
--- src/spikeinterface/sorters/runsorter.py | 17 +++++++++++---- src/spikeinterface/sorters/utils/__init__.py | 2 +- src/spikeinterface/sorters/utils/misc.py | 22 +++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 44a08a34a7..884cba590f 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia +from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies from .container_tools import ( find_recording_folders, path_to_unix, @@ -175,7 +175,7 @@ def run_sorter( "on this machine to run sorting with docker.") if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed." + raise RuntimeError("The python `docker` package must be installed. " "Install with `pip install docker`") else: @@ -191,8 +191,8 @@ def run_sorter( "on this machine to run sorting with singularity.") if not has_spython(): - raise RuntimeError("The python singularity package must be installed." - "Install with `pip install spython`") + raise RuntimeError("The python `spython` package must be installed to " + "run singularity. Install with `pip install spython`") return run_sorter_container( container_image=container_image, @@ -480,6 +480,15 @@ def run_sorter_container( if gpu_capability == "nvidia-required": assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available" extra_kwargs["container_requires_gpu"] = True + + if platform.system() == "Linux" and has_docker_nvidia_installed(): + warn( + f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. " + f"This may result in an error being raised during sorting. Try " + "installing `nvidia-container-toolkit`, including setting the " + "configuration steps, if running into errors." 
+            )
+
         elif gpu_capability == "nvidia-optional":
             if has_nvidia():
                 extra_kwargs["container_requires_gpu"] = True
diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py
index 6cad10b211..7f6f3089d4 100644
--- a/src/spikeinterface/sorters/utils/__init__.py
+++ b/src/spikeinterface/sorters/utils/__init__.py
@@ -1,2 +1,2 @@
 from .shellscript import ShellScript
-from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path
+from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies
diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py
index a1cf34f059..4a900f4485 100644
--- a/src/spikeinterface/sorters/utils/misc.py
+++ b/src/spikeinterface/sorters/utils/misc.py
@@ -82,6 +82,7 @@ def has_nvidia():
     except RuntimeError:  # Failed to dlopen libcuda.so
         return False
 
+
 def _run_subprocess_silently(command):
     output = subprocess.run(
         command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
@@ -90,12 +91,27 @@
 
 
 def has_docker():
-    return self._run_subprocess_silently("docker --version").returncode == 0
+    return _run_subprocess_silently("docker --version").returncode == 0
 
 
 def has_singularity():
-    return self._run_subprocess_silently("singularity --version").returncode == 0
-
+    return _run_subprocess_silently("singularity --version").returncode == 0
+
+def get_nvidia_docker_dependecies():
+    return [
+        "nvidia-docker",
+        "nvidia-docker2",
+        "nvidia-container-toolkit",
+    ]
+
+def has_docker_nvidia_installed():
+    all_dependencies = get_nvidia_docker_dependecies()
+    has_dep = []
+    for dep in all_dependencies:
+        has_dep.append(
+            _run_subprocess_silently(f"{dep} --version").returncode == 0
+        )
+    return not any(has_dep)
 
 def has_docker_python():
     try:

From e0656bb86901127c8b1c0f708e4970584e79a40d Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 20:15:48 +0100
Subject: [PATCH 03/44] Add docstrings.

---
 src/spikeinterface/sorters/utils/misc.py | 44 ++++++++++++++++++------
 1 file changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py
index 4a900f4485..66744fbab1 100644
--- a/src/spikeinterface/sorters/utils/misc.py
+++ b/src/spikeinterface/sorters/utils/misc.py
@@ -84,9 +84,10 @@ def has_nvidia():
 
 
 def _run_subprocess_silently(command):
-    output = subprocess.run(
-        command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-    )
+    """
+    Run a subprocess command without outputting to stderr or stdout.
+    """
+    output = subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     return output
 
 
@@ -97,25 +98,45 @@ def has_docker():
 def has_singularity():
     return _run_subprocess_silently("singularity --version").returncode == 0
 
+
+def has_docker_nvidia_installed():
+    """
+    On Linux, nvidia has a set of container dependencies
+    that are required for running GPU in docker. This is a little
+    complex and is described in more detail in the links below.
+    To summarise briefly, at least one of the `get_nvidia_docker_dependecies()`
+    packages is almost certainly required to run docker with GPU.
+ + https://github.com/NVIDIA/nvidia-docker/issues/1268 + https://www.howtogeek.com/devops/how-to-use-an-nvidia-gpu-with-docker-containers/ + + Returns + ------- + Whether at least one of the dependencies listed in + `get_nvidia_docker_dependecies()` is installed. + """ + all_dependencies = get_nvidia_docker_dependecies() + has_dep = [] + for dep in all_dependencies: + has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0) + return not any(has_dep) + + def get_nvidia_docker_dependecies(): + """ + See `has_docker_nvidia_installed()` + """ return [ "nvidia-docker", "nvidia-docker2", "nvidia-container-toolkit", ] -def has_docker_nvidia_installed(): - all_dependencies = get_nvidia_docker_dependecies() - has_dep = [] - for dep in all_dependencies: - has_dep.append( - _run_subprocess_silently(f"{dep} --version").returncode == 0 - ) - return not any(has_dep) def has_docker_python(): try: import docker + return True except ImportError: return False @@ -124,6 +145,7 @@ def has_docker_python(): def has_spython(): try: import spython + return True except ImportError: return False From b145b04ac31a8de3d9c9fbfc56b4a9974ce0eb3a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:21:51 +0100 Subject: [PATCH 04/44] Add tests for runsorter dependencies. --- src/spikeinterface/sorters/runsorter.py | 35 +++-- .../tests/test_runsorter_dependency_checks.py | 144 ++++++++++++++++++ 2 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 884cba590f..5b2e80b83d 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,18 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies + +# full import required for monkeypatch testing. +from spikeinterface.sorters.utils import ( + SpikeSortingError, + has_nvidia, + has_docker, + has_docker_python, + has_singularity, + has_spython, + has_docker_nvidia_installed, + get_nvidia_docker_dependecies, +) from .container_tools import ( find_recording_folders, path_to_unix, @@ -171,12 +182,14 @@ def run_sorter( container_image = docker_image if not has_docker(): - raise RuntimeError("Docker is not installed. Install docker " - "on this machine to run sorting with docker.") + raise RuntimeError( + "Docker is not installed. Install docker " "on this machine to run sorting with docker." + ) if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed. " - "Install with `pip install docker`") + raise RuntimeError( + "The python `docker` package must be installed. " "Install with `pip install docker`" + ) else: mode = "singularity" @@ -187,12 +200,16 @@ def run_sorter( container_image = singularity_image if not has_singularity(): - raise RuntimeError("Singularity is not installed. Install singularity " - "on this machine to run sorting with singularity.") + raise RuntimeError( + "Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity." + ) if not has_spython(): - raise RuntimeError("The python `spython` package must be installed to " - "run singularity. 
Install with `pip install spython`")
+                raise RuntimeError(
+                    "The python `spython` package must be installed to "
+                    "run singularity. Install with `pip install spython`"
+                )
 
     return run_sorter_container(
         container_image=container_image,
diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
new file mode 100644
index 0000000000..8dbb1b20f6
--- /dev/null
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -0,0 +1,144 @@
+import os
+import pytest
+from pathlib import Path
+import shutil
+import platform
+from spikeinterface import generate_ground_truth_recording
+from spikeinterface.sorters.utils import has_spython, has_docker_python
+from spikeinterface.sorters import run_sorter
+import subprocess
+import sys
+import copy
+
+
+def _monkeypatch_return_false():
+    return False
+
+
+class TestRunersorterDependencyChecks:
+    """
+    This class performs tests to check whether expected
+    dependency checks prior to sorting are run. The
+    run_sorter function should raise an error if:
+    - singularity is not installed
+    - spython is not installed (python package)
+    - docker is not installed
+    - docker is not installed (python package)
+    when running singularity / docker respectively.
+
+    Two separate checks should be run. First, that the
+    relevant `has_` function (indicating if the dependency
+    is installed) is working. Unfortunately it is not possible to
+    easily test this for the core singularity and docker installs, so this is not done.
+    `uninstall_python_dependency()` allows a test to check if the
+    `has_spython()` and `has_docker_python()` return `False` as expected
+    when these python modules are not installed.
+
+    Second, the `run_sorter()` function should raise the appropriate error
+    when these functions report that the dependency is not available. This is
+    easier to test as these `has_` reporting functions can be
+    monkeypatched to return False at runtime. This is done for these 4
+    dependency checks, and tests check the expected error is raised.
+
+    Notes
+    -----
+    `has_nvidia()` and `has_docker_nvidia_installed()` are not tested
+    as these are complex GPU-related dependencies which are difficult to mock.
+    """
+
+    @pytest.fixture(scope="function")
+    def uninstall_python_dependency(self, request):
+        """
+        This python fixture mocks python modules not been importable
+        by setting the relevant `sys.modules` dict entry to `None`.
+        It uses `yeild` so that the function can tear-down the test
+        (even if it failed) and replace the patched `sys.module` entry.
+
+        This function uses an `indirect` parameterisation, meaning the
+        `request.param` is passed to the fixture at the start of the
+        test function. This is used to reuse code for nearly identical
+        `spython` and `docker` python dependency tests.
+        """
+        dep_name = request.param
+        assert dep_name in ["spython", "docker"]
+
+        try:
+            if dep_name == "spython":
+                import spython
+            else:
+                import docker
+            dependency_installed = True
+        except:
+            dependency_installed = False
+
+        if dependency_installed:
+            copy_import = sys.modules[dep_name]
+            sys.modules[dep_name] = None
+        yield
+        if dependency_installed:
+            sys.modules[dep_name] = copy_import
+
+    @pytest.fixture(scope="session")
+    def recording(self):
+        """
+        Make a small recording to have something to pass to the sorter.
+ """ + recording, _ = generate_ground_truth_recording(durations=[10]) + return recording + + @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") + @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) + def test_has_spython(self, recording, uninstall_python_dependency): + """ + Test the `has_spython()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_spython() is False + + @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) + def test_has_docker_python(self, recording, uninstall_python_dependency): + """ + Test the `has_docker_python()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_docker_python() is False + + @pytest.mark.parametrize("dependency", ["singularity", "spython"]) + def test_has_singularity_and_spython(self, recording, monkeypatch, dependency): + """ + When running a sorting, if singularity dependencies (singularity + itself or the `spython` package`) are not installed, an error is raised. + Beacause it is hard to actually uninstall these dependencies, the + `has_` functions that let `run_sorter` know if the dependency + are installed are monkeypatched. This is done so at runtime these always + return False. Then, test the expected error is raised when the dependency + is not found. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + if dependency == "spython": + assert "The python `spython` package must be installed" in str(e) + else: + assert "Singularity is not installed." in str(e) + + @pytest.mark.parametrize("dependency", ["docker", "docker_python"]) + def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency): + """ + See `test_has_singularity_and_spython()` for details. This test + is almost identical, but with some key changes for Docker. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + if dependency == "docker_python": + assert "The python `docker` package must be installed" in str(e) + else: + assert "Docker is not installed." in str(e) From 78ccc2719676b238dbd92d2ad5384786ca0724e0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:24:29 +0100 Subject: [PATCH 05/44] Remove unnecessary non-relative import. --- src/spikeinterface/sorters/runsorter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 5b2e80b83d..c16435cdb5 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,9 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict - -# full import required for monkeypatch testing. 
-from spikeinterface.sorters.utils import (
+from .utils import (
     SpikeSortingError,
     has_nvidia,
     has_docker,

From f1438c4ce20bbd7ae3c910b793f92ebb4d723253 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 21:27:03 +0100
Subject: [PATCH 06/44] Fix some string formatting, add docstring to
 monkeypatch function.

---
 src/spikeinterface/sorters/runsorter.py                      | 6 ++----
 .../sorters/tests/test_runsorter_dependency_checks.py        | 4 ++++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index c16435cdb5..f9994dd38d 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -181,13 +181,11 @@ def run_sorter(
 
             if not has_docker():
                 raise RuntimeError(
-                    "Docker is not installed. Install docker " "on this machine to run sorting with docker."
+                    "Docker is not installed. Install docker on this machine to run sorting with docker."
                 )
 
             if not has_docker_python():
-                raise RuntimeError(
-                    "The python `docker` package must be installed. " "Install with `pip install docker`"
-                )
+                raise RuntimeError("The python `docker` package must be installed. Install with `pip install docker`")
 
         else:
             mode = "singularity"
diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index 8dbb1b20f6..c81593b7db 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -12,6 +12,10 @@
 
 
 def _monkeypatch_return_false():
+    """
+    A function to monkeypatch the `has_` functions,
+    ensuring they always return `False` at runtime.
+    """
     return False

From fd4406e0826f80329614e3b59388e9640c00fe3e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 12 Jun 2024 20:27:36 +0000
Subject: [PATCH 07/44] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/sorters/utils/__init__.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py
index 7f6f3089d4..62317be6f2 100644
--- a/src/spikeinterface/sorters/utils/__init__.py
+++ b/src/spikeinterface/sorters/utils/__init__.py
@@ -1,2 +1,14 @@
 from .shellscript import ShellScript
-from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies
+from .misc import (
+    SpikeSortingError,
+    get_git_commit,
+    has_nvidia,
+    get_matlab_shell_name,
+    get_bash_path,
+    has_docker,
+    has_docker_python,
+    has_singularity,
+    has_spython,
+    has_docker_nvidia_installed,
+    get_nvidia_docker_dependecies,
+)

From 7af611ba289e220c4bf36f4b62ae26efe94f93b1 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 12 Jun 2024 21:42:26 +0100
Subject: [PATCH 08/44] Mock all has functions to ensure tests do not depend
 on actual dependencies.
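
The crux of this change is where the patch is applied: `runsorter.py` binds the
`has_` functions into its own namespace at import time, so the tests must patch
them on `spikeinterface.sorters.runsorter` (the lookup site), not on
`sorters.utils.misc` (the definition site). A condensed sketch of the pattern
used in the test file below, assuming the `recording` fixture defined there:

    import pytest
    from spikeinterface.sorters import run_sorter


    def _return_false():
        return False


    def test_docker_not_installed_raises(recording, monkeypatch):
        # Patch the name where run_sorter looks it up at call time.
        monkeypatch.setattr("spikeinterface.sorters.runsorter.has_docker", _return_false)
        with pytest.raises(RuntimeError, match="Docker is not installed"):
            run_sorter("kilosort2_5", recording, docker_image=True)
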
---
 .../tests/test_runsorter_dependency_checks.py | 58 +++++++++++++------
 1 file changed, 39 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index c81593b7db..a248033089 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -4,7 +4,7 @@
 import shutil
 import platform
 from spikeinterface import generate_ground_truth_recording
-from spikeinterface.sorters.utils import has_spython, has_docker_python
+from spikeinterface.sorters.utils import has_spython, has_docker_python, has_docker, has_singularity
 from spikeinterface.sorters import run_sorter
 import subprocess
 import sys
@@ -19,6 +19,10 @@ def _monkeypatch_return_false():
     return False
 
 
+def _monkeypatch_return_true():
+    return True
+
+
 class TestRunersorterDependencyChecks:
     """
     This class performs tests to check whether expected
@@ -91,6 +95,7 @@ def recording(self):
         return recording
 
     @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.")
+    @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.")
     @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True)
     def test_has_spython(self, recording, uninstall_python_dependency):
         """
@@ -100,6 +105,7 @@ def test_has_spython(self, recording, uninstall_python_dependency):
         assert has_spython() is False
 
     @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True)
+    @pytest.mark.skipif(not has_docker(), reason="docker required for this test.")
     def test_has_docker_python(self, recording, uninstall_python_dependency):
         """
         Test the `has_docker_python()` function, see class docstring and
@@ -107,8 +113,7 @@ def test_has_docker_python(self, recording, uninstall_python_dependency):
         """
         assert has_docker_python() is False
 
-    @pytest.mark.parametrize("dependency", ["singularity", "spython"])
-    def test_has_singularity_and_spython(self, recording, monkeypatch, dependency):
+    def test_no_singularity_error_raised(self, recording, monkeypatch):
         """
         When running a sorting, if singularity dependencies (singularity
         itself or the `spython` package) are not installed, an error is raised.
@@ -118,31 +123,46 @@
         return False. Then, test the expected error is raised when the dependency
         is not found.
         """
-        test_func = f"has_{dependency}"
+        monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_false)
 
-        monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false)
         with pytest.raises(RuntimeError) as e:
             run_sorter("kilosort2_5", recording, singularity_image=True)
 
-        if dependency == "spython":
-            assert "The python `spython` package must be installed" in str(e)
-        else:
-            assert "Singularity is not installed." in str(e)
+        assert "Singularity is not installed." in str(e)
 
-    @pytest.mark.parametrize("dependency", ["docker", "docker_python"])
-    def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency):
+    def test_no_spython_error_raised(self, recording, monkeypatch):
         """
-        See `test_has_singularity_and_spython()` for details. This test
-        is almost identical, but with some key changes for Docker.
+        See `test_no_singularity_error_raised()`.
""" - test_func = f"has_{dependency}" + # make sure singularity test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_spython", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + assert "The python `spython` package must be installed" in str(e) - monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + def test_no_docker_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + assert "Docker is not installed." in str(e) + + def test_as_no_docker_python_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + # make sure docker test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker_python", _monkeypatch_return_false) with pytest.raises(RuntimeError) as e: run_sorter("kilosort2_5", recording, docker_image=True) - if dependency == "docker_python": - assert "The python `docker` package must be installed" in str(e) - else: - assert "Docker is not installed." in str(e) + assert "The python `docker` package must be installed" in str(e) From 0c0b1f908d8e356b9a58cacd4524ace871ff93b3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:43:10 +0100 Subject: [PATCH 09/44] Remove unecessary skips. --- .../sorters/tests/test_runsorter_dependency_checks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index a248033089..741fe4ae0e 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -95,7 +95,6 @@ def recording(self): return recording @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") - @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.") @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) def test_has_spython(self, recording, uninstall_python_dependency): """ @@ -105,7 +104,6 @@ def test_has_spython(self, recording, uninstall_python_dependency): assert has_spython() is False @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) - @pytest.mark.skipif(not has_docker(), reason="docker required for this test.") def test_has_docker_python(self, recording, uninstall_python_dependency): """ Test the `has_docker_python()` function, see class docstring and From 1be1dbd39a339ff56c0803ff7a59e5650d95b781 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 13 Jun 2024 09:04:04 +0100 Subject: [PATCH 10/44] Update docstrings. 
---
 .../sorters/tests/test_runsorter_dependency_checks.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
index 741fe4ae0e..c4beaba072 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py
@@ -20,14 +20,18 @@ def _monkeypatch_return_false():
 
 
 def _monkeypatch_return_true():
+    """
+    Monkeypatch for some `has_` functions to
+    return `True` so that functions later in the
+    `runsorter` code can be checked.
+    """
     return True
 
 
 class TestRunersorterDependencyChecks:
     """
-    This class performs tests to check whether expected
-    dependency checks prior to sorting are run. The
-    run_sorter function should raise an error if:
+    This class tests whether expected dependency checks prior to sorting are run.
+    The run_sorter function should raise an error if:
     - singularity is not installed
     - spython is not installed (python package)
     - docker is not installed

From 00663080b03f7933d37ba4ff2ee32e3402aa200e Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Thu, 13 Jun 2024 09:10:30 +0100
Subject: [PATCH 11/44] Swap return bool to match function name.

---
 src/spikeinterface/sorters/runsorter.py  | 2 +-
 src/spikeinterface/sorters/utils/misc.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index f9994dd38d..80608f8973 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -494,7 +494,7 @@ def run_sorter_container(
         assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available"
         extra_kwargs["container_requires_gpu"] = True
 
-        if platform.system() == "Linux" and has_docker_nvidia_installed():
+        if platform.system() == "Linux" and not has_docker_nvidia_installed():
             warn(
                 f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. "
                 f"This may result in an error being raised during sorting. Try "
Try " diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 66744fbab1..1e01b9c052 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -119,7 +119,7 @@ def has_docker_nvidia_installed(): has_dep = [] for dep in all_dependencies: has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0) - return not any(has_dep) + return any(has_dep) def get_nvidia_docker_dependecies(): From 9664f69c4bcdd24e20584f601bcbd6a9ae79e174 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jun 2024 11:49:32 +0200 Subject: [PATCH 12/44] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../sorters/tests/test_runsorter_dependency_checks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index c4beaba072..83d6ec3161 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -13,7 +13,7 @@ def _monkeypatch_return_false(): """ - A function to monkeypatch the `has_` functions, + A function to monkeypatch the `has_` functions, ensuring the always return `False` at runtime. """ return False @@ -61,12 +61,12 @@ class TestRunersorterDependencyChecks: @pytest.fixture(scope="function") def uninstall_python_dependency(self, request): """ - This python fixture mocks python modules not been importable + This python fixture mocks python modules not being importable by setting the relevant `sys.modules` dict entry to `None`. - It uses `yeild` so that the function can tear-down the test + It uses `yield` so that the function can tear-down the test (even if it failed) and replace the patched `sys.module` entry. - This function uses an `indirect` parameterisation, meaning the + This function uses an `indirect` parameterization, meaning the `request.param` is passed to the fixture at the start of the test function. This is used to reuse code for nearly identical `spython` and `docker` python dependency tests. 
From 543cc8f2a67719e4ae8b5b64a198a6c7256406e4 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 19 Jun 2024 18:12:31 +0100 Subject: [PATCH 13/44] Add apptainer case to 'has_singularity()' Co-authored-by: Alessio Buccino --- src/spikeinterface/sorters/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 1e01b9c052..82480ffe0a 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,7 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 + return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 def has_docker_nvidia_installed(): From dceb08070af9954b25c99c82ed2df314ef924aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:12:51 +0000 Subject: [PATCH 14/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/utils/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 82480ffe0a..9c8c3bba89 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,10 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 + return ( + _run_subprocess_silently("singularity --version").returncode == 0 + or _run_subprocess_silently("apptainer --version").returncode == 0 + ) def has_docker_nvidia_installed(): From 617649569e147f8a530d6cfd0c0637857481e367 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 13:14:37 -0600 Subject: [PATCH 15/44] improve error log to json in run_sorter --- src/spikeinterface/sorters/basesorter.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8c52626703..799444ddbd 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -262,7 +262,12 @@ def run_from_folder(cls, output_folder, raise_error, verbose): has_error = True run_time = None log["error"] = True - log["error_trace"] = traceback.format_exc() + error_log_to_display = traceback.format_exc() + trace_lines = error_log_to_display.strip().split("\n") + error_to_json = ["Traceback (most recent call last):"] + [ + f" {line}" if not line.startswith(" ") else line for line in trace_lines[1:] + ] + log["error_trace"] = error_to_json log["error"] = has_error log["run_time"] = run_time @@ -290,7 +295,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): if has_error and raise_error: raise SpikeSortingError( - f"Spike sorting error trace:\n{log['error_trace']}\n" + f"Spike sorting error trace:\n{error_log_to_display}\n" f"Spike sorting failed. You can inspect the runtime trace in {output_folder}/spikeinterface_log.json." 
            )
diff --git a/src/spikeinterface/widgets/driftmap.py b/src/spikeinterface/widgets/driftmap.py
new file mode 100644
index 0000000000..60e8df2972
--- /dev/null
+++ b/src/spikeinterface/widgets/driftmap.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import numpy as np
+
+from .base import BaseWidget, to_attr
+
+
+class DriftMapWidget(BaseWidget):
+    """
+    Plot a drift map from a set of detected peaks and their estimated locations.
+
+    Parameters
+    ----------
+    peaks : np.array
+        The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index")
+    peak_locations : np.array
+        The peak locations, with dtype ("x", "y") or ("x", "y", "z")
+    direction : "x" or "y", default: "y"
+        The direction to display
+    segment_index : int or None, default: None
+        The segment index to display.
+    recording : RecordingExtractor, default: None
+        The recording extractor object (only used to get "real" times)
+    sampling_frequency : float, default: None
+        The sampling frequency (needed if recording is None)
+    depth_lim : tuple or None, default: None
+        The min and max depth to display, if None (min and max of the recording)
+    color_amplitude : bool, default: True
+        If True, the color of the scatter points is the amplitude of the peaks
+    scatter_decimate : int, default: None
+        If > 1, the scatter points are decimated
+    cmap : str, default: "inferno"
+        The colormap to use for the amplitude
+    clim : tuple or None, default: None
+        The min and max amplitude to display, if None (min and max of the amplitudes)
+    alpha : float, default: 1
+        The alpha of the scatter points
+    """
+
+    def __init__(
+        self,
+        peaks,
+        peak_locations,
+        direction="y",
+        recording=None,
+        sampling_frequency=None,
+        segment_index=None,
+        depth_lim=None,
+        color_amplitude=True,
+        scatter_decimate=None,
+        cmap="inferno",
+        clim=None,
+        alpha=1,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if segment_index is None:
+            assert (
+                len(np.unique(peaks["segment_index"])) == 1
+            ), "segment_index must be specified if there is more than one segment in the peaks array"
+        assert recording or sampling_frequency, "recording or sampling_frequency must be specified"
+        if recording is not None:
+            sampling_frequency = recording.sampling_frequency
+            times = recording.get_times(segment_index=segment_index)
+        else:
+            times = None
+
+        plot_data = dict(
+            peaks=peaks,
+            peak_locations=peak_locations,
+            direction=direction,
+            times=times,
+            sampling_frequency=sampling_frequency,
+            segment_index=segment_index,
+            depth_lim=depth_lim,
+            color_amplitude=color_amplitude,
+            scatter_decimate=scatter_decimate,
+            cmap=cmap,
+            clim=clim,
+            alpha=alpha,
+            recording=recording,
+        )
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from matplotlib.colors import Normalize
+
+        from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks
+
+        dp = to_attr(data_plot)
+
+        assert backend_kwargs["axes"] is None, "axes argument is not allowed in DriftMapWidget"
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+        fig = self.figure
+
+        if dp.times is None:
+            x = dp.peaks["sample_index"] / dp.sampling_frequency
+        else:
+            # use real times and adjust temporal bins with t_start
+            x = dp.times[dp.peaks["sample_index"]]
+
+        y = dp.peak_locations[dp.direction]
+        if dp.scatter_decimate is not None:
+            x = x[:: dp.scatter_decimate]
+            y = y[:: dp.scatter_decimate]
+
+        if dp.color_amplitude:
+            amps = dp.peaks["amplitude"]
+            amps_abs = np.abs(amps)
+            q_95 = np.quantile(amps_abs, 0.95)
+            if dp.scatter_decimate is not None:
+                amps = amps[:: dp.scatter_decimate]
+                amps_abs = amps_abs[:: dp.scatter_decimate]
+            cmap = plt.colormaps[dp.cmap]
+            if dp.clim is None:
+                amps = amps_abs
+                amps /= q_95
+                c = cmap(amps)
+            else:
+                norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True)
+                c = cmap(norm_function(amps))
+            color_kwargs = dict(
+                color=None,
+                c=c,
+                alpha=dp.alpha,
+            )
+        else:
+            color_kwargs = dict(color="k", c=None, alpha=dp.alpha)
+
+        self.ax.scatter(x, y, s=1, **color_kwargs)
+        if dp.depth_lim is not None:
+            self.ax.set_ylim(*dp.depth_lim)
+        self.ax.set_title("Peak depth")
+        self.ax.set_xlabel("Times [s]")
+        self.ax.set_ylabel("Depth [$\\mu$m]")
diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index fc0c91423d..7d733523df 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from .base import BaseWidget, to_attr
+from .driftmap import DriftMapWidget
 
 
 class MotionWidget(BaseWidget):
@@ -107,7 +108,7 @@ class MotionInfoWidget(BaseWidget):
     Parameters
     ----------
     motion_info : dict
-        The motion info return by correct_motion() or load back with load_motion_info()
+        The motion info returned by correct_motion() or loaded back with load_motion_info()
     segment_index : int, default: None
         The segment index to display.
     recording : RecordingExtractor, default: None
         The recording extractor object (only used to get "real" times)
@@ -153,7 +154,9 @@ def __init__(
             if len(motion.displacement) == 1:
                 segment_index = 0
             else:
-                raise ValueError("plot motion : teh Motion object is multi segment you must provide segmentindex=XX")
+                raise ValueError(
+                    "plot drift map : the Motion object is multi-segment you must provide segment_index=XX"
+                )
 
         times = recording.get_times() if recording is not None else None
 
@@ -214,14 +217,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         ax1.sharex(ax0)
         ax1.sharey(ax0)
 
-        if dp.times is None:
-            # temporal_bins_plot = dp.temporal_bins
-            x = dp.peaks["sample_index"] / dp.sampling_frequency
-        else:
-            # use real times and adjust temporal bins with t_start
-            # temporal_bins_plot = dp.temporal_bins + dp.times[0]
-            x = dp.times[dp.peaks["sample_index"]]
-
         corrected_location = correct_motion_on_peaks(
             dp.peaks,
             dp.peak_locations,
@@ -229,47 +224,34 @@
             dp.recording,
         )
 
-        y = dp.peak_locations[motion.direction]
-        y2 = corrected_location[motion.direction]
-        if dp.scatter_decimate is not None:
-            x = x[:: dp.scatter_decimate]
-            y = y[:: dp.scatter_decimate]
-            y2 = y2[:: dp.scatter_decimate]
-
-        if dp.color_amplitude:
-            amps = dp.peaks["amplitude"]
-            amps_abs = np.abs(amps)
-            q_95 = np.quantile(amps_abs, 0.95)
-            if dp.scatter_decimate is not None:
-                amps = amps[:: dp.scatter_decimate]
-                amps_abs = amps_abs[:: dp.scatter_decimate]
-            cmap = plt.colormaps[dp.amplitude_cmap]
-            if dp.amplitude_clim is None:
-                amps = amps_abs
-                amps /= q_95
-                c = cmap(amps)
-            else:
-                norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True)
-                c = cmap(norm_function(amps))
-            color_kwargs = dict(
-                color=None,
-                c=c,
-                alpha=dp.amplitude_alpha,
-            )
-        else:
-            color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha)
-
-        ax0.scatter(x, y, s=1, **color_kwargs)
-        if dp.depth_lim is not None:
-            ax0.set_ylim(*dp.depth_lim)
-        ax0.set_title("Peak depth")
-        ax0.set_xlabel("Times [s]")
-        ax0.set_ylabel("Depth [$\\mu$m]")
-
-        ax1.scatter(x, y2, s=1, **color_kwargs)
-        ax1.set_xlabel("Times [s]")
-        ax1.set_ylabel("Depth [$\\mu$m]")
-        ax1.set_title("Corrected peak depth")
+        common_drift_map_kwargs = dict(
+            direction=dp.motion.direction,
+            recording=dp.recording,
+            segment_index=dp.segment_index,
+            depth_lim=dp.depth_lim,
+            color_amplitude=dp.color_amplitude,
+            scatter_decimate=dp.scatter_decimate,
+            cmap=dp.amplitude_cmap,
+            clim=dp.amplitude_clim,
+            alpha=dp.amplitude_alpha,
+            backend="matplotlib",
+        )
+
+        drift_map = DriftMapWidget(
+            dp.peaks,
+            dp.peak_locations,
+            ax=ax0,
+            immediate_plot=True,
+            **common_drift_map_kwargs,
+        )
+
+        drift_map_corrected = DriftMapWidget(
+            dp.peaks,
+            corrected_location,
+            ax=ax1,
+            immediate_plot=True,
+            **common_drift_map_kwargs,
+        )
 
         ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
         ax2.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0")
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 6367e098ea..8d4accaa7e 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -9,6 +9,7 @@
 from .amplitudes import AmplitudesWidget
 from .autocorrelograms import AutoCorrelogramsWidget
 from .crosscorrelograms import CrossCorrelogramsWidget
+from .driftmap import DriftMapWidget
 from .isi_distribution import ISIDistributionWidget
 from .motion import MotionWidget, MotionInfoWidget
 from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget
@@ -44,6 +45,7 @@
     ConfusionMatrixWidget,
     ComparisonCollisionBySimilarityWidget,
     CrossCorrelogramsWidget,
+    DriftMapWidget,
     ISIDistributionWidget,
     MotionWidget,
     MotionInfoWidget,
@@ -118,6 +120,7 @@
 plot_confusion_matrix = ConfusionMatrixWidget
 plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget
 plot_crosscorrelograms = CrossCorrelogramsWidget
+plot_drift_map = DriftMapWidget
 plot_isi_distribution = ISIDistributionWidget
 plot_motion = MotionWidget
 plot_motion_info = MotionInfoWidget

From baf1287215e41b020dc97b5d6428dbdc5446ef76 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Tue, 25 Jun 2024 11:20:06 +0100
Subject: [PATCH 17/44] Fix docstrings for extractors module

---
 doc/api.rst                                          |  2 +-
 doc/modules/extractors.rst                           |  2 +-
 src/spikeinterface/extractors/cbin_ibl.py            |  4 +++-
 .../extractors/herdingspikesextractors.py            |  2 +-
 src/spikeinterface/extractors/iblextractors.py       |  2 +-
 .../extractors/neoextractors/alphaomega.py           |  5 +++++
 .../extractors/neoextractors/biocam.py               |  1 -
 .../extractors/neoextractors/blackrock.py            |  3 ++-
 .../extractors/neoextractors/ced.py                  |  2 --
 .../extractors/neoextractors/intan.py                |  2 ++
 .../extractors/neoextractors/maxwell.py              |  2 ++
 .../extractors/neoextractors/neuralynx.py            | 11 ++++++-----
 .../extractors/neoextractors/plexon2.py              |  2 +-
 .../extractors/neoextractors/spikegadgets.py         |  2 +-
 .../extractors/neoextractors/tdt.py                  |  2 ++
 src/spikeinterface/extractors/toy_example.py         | 10 ++++++++--
 src/spikeinterface/preprocessing/filter.py           | 17 ++++++++---------
 17 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/doc/api.rst b/doc/api.rst
index a7476cd62f..c5c9ebe4dd 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -117,7 +117,7 @@ Non-NEO-based
     .. autofunction:: read_bids
     .. autofunction:: read_cbin_ibl
     .. autofunction:: read_combinato
-    .. autofunction:: read_ibl_streaming_recording
+    .. autofunction:: read_ibl_recording
     .. autofunction:: read_hdsort
     .. autofunction:: read_herdingspikes
     .. autofunction:: read_kilosort
diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst
index 2d0e047672..ba08e45aca 100644
--- a/doc/modules/extractors.rst
+++ b/doc/modules/extractors.rst
@@ -125,7 +125,7 @@ For raw recording formats, we currently support:
 * **Biocam HDF5** :py:func:`~spikeinterface.extractors.read_biocam()`
 * **CED** :py:func:`~spikeinterface.extractors.read_ced()`
 * **EDF** :py:func:`~spikeinterface.extractors.read_edf()`
-* **IBL streaming** :py:func:`~spikeinterface.extractors.read_ibl_streaming_recording()`
+* **IBL streaming** :py:func:`~spikeinterface.extractors.read_ibl_recording()`
 * **Intan** :py:func:`~spikeinterface.extractors.read_intan()`
 * **MaxWell** :py:func:`~spikeinterface.extractors.read_maxwell()`
 * **MCS H5** :py:func:`~spikeinterface.extractors.read_mcsh5()`
diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py
index a6da19408f..1687acb073 100644
--- a/src/spikeinterface/extractors/cbin_ibl.py
+++ b/src/spikeinterface/extractors/cbin_ibl.py
@@ -27,9 +27,11 @@ class CompressedBinaryIblExtractor(BaseRecording):
     load_sync_channel : bool, default: False
         Load or not the last channel (sync).
        If not then the probe is loaded.
-    stream_name : str, default: "ap".
+    stream_name : {"ap", "lp"}, default: "ap".
         Whether to load AP or LFP band, one
         of "ap" or "lp".
+    cbin_file : str or None, default: None
+        The cbin file of the recording.
+        If None, searches in `folder_path` for file.
 
     Returns
     -------
diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py
index 139d51d62e..87f7dd74c4 100644
--- a/src/spikeinterface/extractors/herdingspikesextractors.py
+++ b/src/spikeinterface/extractors/herdingspikesextractors.py
@@ -20,7 +20,7 @@ class HerdingspikesSortingExtractor(BaseSorting):
 
     Parameters
     ----------
-    folder_path : str or Path
+    file_path : str or Path
         Path to the ALF folder.
     load_unit_info : bool, default: True
         Whether to load the unit info from the file.
diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py
index 2444314aec..27bb95854f 100644
--- a/src/spikeinterface/extractors/iblextractors.py
+++ b/src/spikeinterface/extractors/iblextractors.py
@@ -41,7 +41,7 @@ class IblRecordingExtractor(BaseRecording):
     stream_name : str
         The name of the stream to load for the session.
         These can be retrieved from calling `StreamingIblExtractor.get_stream_names(session="")`.
-    load_sync_channels : bool, default: false
+    load_sync_channel : bool, default: False
         Load or not the last channel (sync).
         If not then the probe is loaded.
     cache_folder : str or None, default: None
diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py
index 5c8e58d3a5..239928f66d 100644
--- a/src/spikeinterface/extractors/neoextractors/alphaomega.py
+++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py
@@ -50,6 +50,11 @@ def map_to_neo_kwargs(cls, folder_path, lsx_files=None):
 class AlphaOmegaEventExtractor(NeoBaseEventExtractor):
     """
     Class for reading events from AlphaOmega MPX file format
+
+    Parameters
+    ----------
+    folder_path : str or Path-like
+        The folder path to the AlphaOmega events.
     """
 
     mode = "folder"
diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py
index 96d4dd25a6..9f23575dba 100644
--- a/src/spikeinterface/extractors/neoextractors/biocam.py
+++ b/src/spikeinterface/extractors/neoextractors/biocam.py
@@ -42,7 +42,6 @@ def __init__(
         electrode_width=None,
         stream_id=None,
         stream_name=None,
-        block_index=None,
         all_annotations=False,
     ):
         neo_kwargs = self.map_to_neo_kwargs(file_path)
diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py
index 5e28c4a20d..0015fd9f67 100644
--- a/src/spikeinterface/extractors/neoextractors/blackrock.py
+++ b/src/spikeinterface/extractors/neoextractors/blackrock.py
@@ -26,6 +26,8 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream name you want to load.
     all_annotations : bool, default: False
         Load exhaustively all annotations from neo.
+    use_names_as_ids : bool, default: False
+        If True, use channel names as IDs. If False, use default IDs.
""" mode = "file" @@ -37,7 +39,6 @@ def __init__( file_path, stream_id=None, stream_name=None, - block_index=None, all_annotations=False, use_names_as_ids=False, ): diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 401c927fc7..e2c79478fa 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -23,8 +23,6 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream id you want to load. stream_name : str, default: None If there are several streams, specify the stream name you want to load. - block_index : int, default: None - If there are several blocks, specify the block index you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. """ diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index c37ff47807..9d4db3103c 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -27,6 +27,8 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): If True, data that violates integrity assumptions will be loaded. At the moment the only integrity check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. + use_names_as_ids : bool or None, default: None + If True, use channel names as IDs. If None, use default IDs. """ mode = "file" diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 3888b6d5a0..a66075b451 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -35,6 +35,8 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): you want to extract. (rec_name='rec0000'). install_maxwell_plugin : bool, default: False If True, install the maxwell plugin for neo. + block_index : int, default: None + If there are several blocks (experiments), specify the block index you want to load """ mode = "file" diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 25b6bb5b61..0670371ba9 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -26,16 +26,17 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. - exlude_filename : list[str], default: None + exclude_filename : list[str], default: None List of filename to exclude from the loading. For example, use `exclude_filename=["events.nev"]` to skip loading the event file. strict_gap_mode : bool, default: False See neo documentation. Detect gaps using strict mode or not. - * strict_gap_mode = True then a gap is consider when timstamp difference between two - consecutive data packets is more than one sample interval. - * strict_gap_mode = False then a gap has an increased tolerance. 
Some new systems with different clocks need this option
-          otherwise, too many gaps are detected
+        * strict_gap_mode = True then a gap is considered when the timestamp difference between
+          two consecutive data packets is more than one sample interval.
+        * strict_gap_mode = False then a gap has an increased tolerance. Some new systems
+          with different clocks need this option, otherwise too many gaps are detected.
+
         Note that here the default is False contrary to neo.

diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py
index 941158def1..c7351a308b 100644
--- a/src/spikeinterface/extractors/neoextractors/plexon2.py
+++ b/src/spikeinterface/extractors/neoextractors/plexon2.py
@@ -19,7 +19,7 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream id you want to load.
     stream_name : str, default: None
         If there are several streams, specify the stream name you want to load.
-    use_names_as_ids:
+    use_names_as_ids : bool, default: True
         If True, the names of the signals are used as channel ids. If False, the channel ids
         are a combination of the source id and the channel index.
diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py
index f326c49cd1..3d57817f88 100644
--- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py
+++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py
@@ -32,7 +32,7 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor):
     NeoRawIOClass = "SpikeGadgetsRawIO"
     name = "spikegadgets"
 
-    def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False):
+    def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False):
         neo_kwargs = self.map_to_neo_kwargs(file_path)
         NeoBaseRecordingExtractor.__init__(
             self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs
diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py
index 146f6a4b4c..27b456102f 100644
--- a/src/spikeinterface/extractors/neoextractors/tdt.py
+++ b/src/spikeinterface/extractors/neoextractors/tdt.py
@@ -23,6 +23,8 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream name you want to load.
     all_annotations : bool, default: False
         Load exhaustively all annotations from neo.
+    block_index : int, default: None
+        If there are several blocks (experiments), specify the block index you want to load
diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py
index 450044d07b..2f007cca88 100644
--- a/src/spikeinterface/extractors/toy_example.py
+++ b/src/spikeinterface/extractors/toy_example.py
@@ -57,12 +57,18 @@ def toy_example(
         Spike time in the recording
     spike_labels : np.array or list[nparray] or None, default: None
         Cluster label for each spike time (needs to specified both together).
-    # score_detection : int (between 0 and 1)
-    #     Generate the sorting based on a subset of spikes compare with the trace generation
     firing_rate : float, default: 3.0
         The firing rate for the units (in Hz)
     seed : int or None, default: None
         Seed for random initialization.
+    upsample_factor : None or int, default: None
+        A upsampling factor used only when templates are not provided.
+    num_columns : int, default: 1
+        Number of columns in probe.
+ average_peak_amplitude : float, default: -100 + Average peak amplitude of generated templates + contact_spacing_um : float, default: 40.0 + Spacing between probe contacts. Returns ------- diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 6a1733c57c..d18227ca83 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -10,15 +10,14 @@ _common_filter_docs = """**filter_kwargs : dict Certain keyword arguments for `scipy.signal` filters: - filter_order : order - The order of the filter - filter_mode : "sos" | "ba", default: "sos" - Filter form of the filter coefficients: - - second-order sections ("sos") - - numerator/denominator : ("ba") - ftype : str, default: "butter" - Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". - """ + filter_order : order + The order of the filter + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".""" class FilterRecording(BasePreprocessor): From b88ddcb9969e01c019452e1a0d1832b092390ea8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:17:09 +0100 Subject: [PATCH 18/44] Respond to review --- src/spikeinterface/extractors/neoextractors/intan.py | 4 ++-- src/spikeinterface/extractors/toy_example.py | 6 +++--- src/spikeinterface/preprocessing/filter.py | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 9d4db3103c..50fda79123 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -27,8 +27,8 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): If True, data that violates integrity assumptions will be loaded. At the moment the only integrity check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. - use_names_as_ids : bool or None, default: None - If True, use channel names as IDs. If None, use default IDs. + use_names_as_ids : bool, default: False + If True, use channel names as IDs. If False, use default IDs inherited from neo. """ mode = "file" diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index 2f007cca88..55b787f3ed 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -62,13 +62,13 @@ def toy_example( seed : int or None, default: None Seed for random initialization. upsample_factor : None or int, default: None - A upsampling factor used only when templates are not provided. + An upsampling factor, used only when templates are not provided. num_columns : int, default: 1 Number of columns in probe. average_peak_amplitude : float, default: -100 - Average peak amplitude of generated templates + Average peak amplitude of generated templates. contact_spacing_um : float, default: 40.0 - Spacing between probe contacts. + Spacing between probe contacts in micrometers. 
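
Editor's note: a short usage sketch of toy_example with the parameters documented above. Only the keyword names come from the docstring; all values below are arbitrary illustrative assumptions, not part of the patch.

# Usage sketch for the toy_example parameters documented above.
from spikeinterface.extractors import toy_example

recording, sorting = toy_example(
    duration=10.0,
    num_units=8,
    firing_rate=3.0,            # Hz, the documented default
    num_columns=1,              # single-column probe layout
    average_peak_amplitude=-100.0,
    contact_spacing_um=40.0,
    seed=42,
)
print(recording)
print(sorting)
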
Returns
     -------
 
diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py
index d18227ca83..93462ac5d8 100644
--- a/src/spikeinterface/preprocessing/filter.py
+++ b/src/spikeinterface/preprocessing/filter.py
@@ -11,7 +11,9 @@
 _common_filter_docs = """**filter_kwargs : dict
     Certain keyword arguments for `scipy.signal` filters:
         filter_order : order
-            The order of the filter
+            The order of the filter. Note that, as filtering is applied with scipy's
+            `filtfilt` functions (i.e. acausal, zero-phase), the effective
+            order will be double the `filter_order`.
         filter_mode : "sos" | "ba", default: "sos"
             Filter form of the filter coefficients:
             - second-order sections ("sos")
 
From a166e5a3d419c49aa6afc69f0e2f98ea7eb9d0c3 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 26 Jun 2024 15:33:51 -0600
Subject: [PATCH 19/44] add recording iterator

---
 src/spikeinterface/core/core_tools.py         | 58 +++++++++++++++++--
 src/spikeinterface/sorters/container_tools.py | 11 +---
 2 files changed, 55 insertions(+), 14 deletions(-)

diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py
index f3d8b3df7f..3fe4939524 100644
--- a/src/spikeinterface/core/core_tools.py
+++ b/src/spikeinterface/core/core_tools.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from pathlib import Path, WindowsPath
-from typing import Union
+from typing import Union, Generator
 import os
 import sys
 import datetime
@@ -8,6 +8,7 @@
 from copy import deepcopy
 import importlib
 from math import prod
+from collections import namedtuple
 
 import numpy as np
 
@@ -183,6 +184,50 @@ def is_dict_extractor(d: dict) -> bool:
     return is_extractor
 
 
+recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"])
+
+
+def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]:
+    """
+    Iterator for recursive traversal of a dictionary.
+    This function explores the dictionary recursively and yields the path to each value along with the value itself.
+
+    By path here we mean the keys that lead to the value in the dictionary:
+    e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b').
+
+    See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure.
+
+    Parameters
+    ----------
+    extractor_dict : dict
+        Input dictionary
+
+    Yields
+    ------
+    recording_dict_element
+        Named tuple containing the value, the name, and the access_path to the value in the dictionary.
+
+    """
+
+    def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""):
+        if isinstance(dict_list_or_value, dict):
+            for k, v in dict_list_or_value.items():
+                yield from _recording_dict_iterator(v, access_path + (k,), name=k)
+        elif isinstance(dict_list_or_value, list):
+            for i, v in enumerate(dict_list_or_value):
+                yield from _recording_dict_iterator(
+                    v, access_path + (i,), name=name
+                )  # Propagate name of list to children
+        else:
+            yield recording_dict_element(
+                value=dict_list_or_value,
+                name=name,
+                access_path=access_path,
+            )
+
+    yield from _recording_dict_iterator(extractor_dict)
+
+
 def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
     """
     Generic function for recursive modification of paths in an extractor dict.
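
Editor's note: as a companion to the `filter_order` note added in the filter.py hunk above, here is a minimal sketch of passing the documented `filter_kwargs` through a SpikeInterface filter. The cutoff values are arbitrary assumptions; only the keyword names come from the docstring.

# Minimal sketch: the documented filter_kwargs forwarded to bandpass_filter.
# With filter_order=5 the effective order is 10, because the filter is applied
# forward and backward (scipy filtfilt / sosfiltfilt, zero-phase).
from spikeinterface.extractors import toy_example
from spikeinterface.preprocessing import bandpass_filter

recording, _ = toy_example(duration=5.0, seed=0)
recording_filtered = bandpass_filter(
    recording,
    freq_min=300.0,
    freq_max=6000.0,
    filter_order=5,
    filter_mode="sos",
    ftype="butter",
)
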
@@ -250,15 +295,16 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d): +def _get_paths_list(d: dict) -> list[str | Path]: # this explore a dict and get all paths flatten in a list # the trick is to use a closure func called by recursive_path_modifier() - path_list = [] - def append_to_path(p): - path_list.append(p) + element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)] + + # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks + # path_list = [p for p in path_list if Path(p).exists()] - recursive_path_modifier(d, append_to_path, target="path", copy=True) return path_list diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 60eb080ae5..8e03090eaf 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -9,19 +9,14 @@ # TODO move this inside functions -from spikeinterface.core.core_tools import recursive_path_modifier +from spikeinterface.core.core_tools import recursive_path_modifier, _get_paths_list def find_recording_folders(d): """Finds all recording folders 'paths' in a dict""" - folders_to_mount = [] - def append_parent_folder(p): - p = Path(p) - folders_to_mount.append(p.resolve().absolute().parent) - return p - - _ = recursive_path_modifier(d, append_parent_folder, target="path", copy=True) + path_list = _get_paths_list(d=d) + folders_to_mount = [Path(p).resolve().parent for p in path_list] try: # this will fail if on different drives (Windows) base_folders_to_mount = [Path(os.path.commonpath(folders_to_mount))] From 27a7c9a96c2e8f008109c99d8dd90ac52ac5fd3e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 16:58:39 -0600 Subject: [PATCH 20/44] add and fix tests --- src/spikeinterface/core/core_tools.py | 83 ++++++++-- .../core/tests/test_core_tools.py | 153 ++++++++++++------ 2 files changed, 170 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 3fe4939524..9e90b56c8d 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -187,7 +187,7 @@ def is_dict_extractor(d: dict) -> bool: recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) -def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. 
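
Editor's note: to make the traversal concrete, a small sketch of the (renamed) iterator on a toy extractor dict; the printed tuples follow directly from the implementation above. Assumes this branch is installed; the dictionary is a stand-in for a real extractor dict.

# Traversal sketch for extractor_dict_iterator (recording_dict_iterator
# before this rename).
from spikeinterface.core.core_tools import extractor_dict_iterator

extractor_dict = {
    "kwargs": {
        "path": "/data/session1/traces.bin",
        "recording": {"kwargs": {"folder_path": "/data/session0"}},
    }
}

for element in extractor_dict_iterator(extractor_dict):
    print(element.access_path, element.name, element.value)
# ('kwargs', 'path') path /data/session1/traces.bin
# ('kwargs', 'recording', 'kwargs', 'folder_path') folder_path /data/session0
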
@@ -209,13 +209,13 @@ def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el
     """
 
-    def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""):
+    def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
         if isinstance(dict_list_or_value, dict):
             for k, v in dict_list_or_value.items():
-                yield from _recording_dict_iterator(v, access_path + (k,), name=k)
+                yield from _extractor_dict_iterator(v, access_path + (k,), name=k)
         elif isinstance(dict_list_or_value, list):
             for i, v in enumerate(dict_list_or_value):
-                yield from _recording_dict_iterator(
+                yield from _extractor_dict_iterator(
                     v, access_path + (i,), name=name
                 )  # Propagate name of list to children
         else:
@@ -225,7 +225,32 @@ def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""):
                 access_path=access_path,
             )
 
-    yield from _recording_dict_iterator(extractor_dict)
+    yield from _extractor_dict_iterator(extractor_dict)
+
+
+def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value):
+    """
+    In place modification of a value in a nested dictionary given its access path.
+
+    Parameters
+    ----------
+    extractor_dict : dict
+        The dictionary to modify
+    access_path : tuple
+        The path to the value in the dictionary
+    new_value : object
+        The new value to set
+
+    Returns
+    -------
+    None
+        The input dictionary is modified in place.
+    """
+
+    current = extractor_dict
+    for key in access_path[:-1]:
+        current = current[key]
+    current[access_path[-1]] = new_value
 
 
 def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
@@ -295,12 +320,13 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
         raise ValueError(f"{k} key for path must be str or list[str]")
 
 
-def _get_paths_list(d: dict) -> list[str | Path]:
-    # this explore a dict and get all paths flatten in a list
-    # the trick is to use a closure func called by recursive_path_modifier()
+# This is the current definition that an element in a recording_dict is a path
+# This is shared across a couple of definitions so it is here for DNRY
+element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))
 
-    element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))
-    path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)]
+
+def _get_paths_list(d: dict) -> list[str | Path]:
+    path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)]
 
     # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks
     # path_list = [p for p in path_list if Path(p).exists()]
@@ -364,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
     return len(not_possible) == 0
 
 
-def make_paths_relative(input_dict, relative_folder) -> dict:
+def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict:
     """
     Recursively transform a dict describing an BaseExtractor to make every path relative to a folder.
 
@@ -380,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict:
     output_dict: dict
         A copy of the input dict with modified paths.
""" + relative_folder = Path(relative_folder).resolve().absolute() - func = lambda p: _relative_to(p, relative_folder) - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + # Only paths that exist are made relative + path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()] + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + new_value = _relative_to(element.value, relative_folder) + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict @@ -405,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder): base_folder = Path(base_folder) # use as_posix instead of str to make the path unix like even on window func = lambda p: (base_folder / p).resolve().absolute().as_posix() - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + output_dict = deepcopy(input_dict) + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + absolute_path = (base_folder / element.value).resolve() + if Path(absolute_path).exists(): + new_value = absolute_path.as_posix() # Not so sure about this, Sam + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict def recursive_key_finder(d, key): # Find all values for a key on a dictionary, even if nested + # TODO refactor to use extractor_dict_iterator + for k, v in d.items(): if isinstance(v, dict): yield from recursive_key_finder(v, key) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 8e00dcb779..043e0cabf3 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -51,14 +51,9 @@ def test_path_utils_functions(): assert d2["kwargs"]["path"].startswith("/yop") assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") - d3 = make_paths_relative(d, Path("/yep")) - assert d3["kwargs"]["path"] == "sub/path1" - assert d3["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - - d4 = make_paths_absolute(d3, "/yop") - assert d4["kwargs"]["path"].startswith("/yop") - assert d4["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_relative_path_on_windows(): if platform.system() == "Windows": # test for windows Path d = { @@ -74,57 +69,111 @@ def test_path_utils_functions(): } } - d2 = make_paths_relative(d, "c:\\yep") - # the str be must unix like path even on windows for more portability - assert d2["kwargs"]["path"] == "sub/path1" - assert d2["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - # same drive assert check_paths_relative(d, r"c:\yep") # not the same drive assert not check_paths_relative(d, r"d:\yep") - d = { - "kwargs": { - "path": r"\\host\share\yep\sub\path1", - } - } - # UNC cannot be relative to d: drive - assert not check_paths_relative(d, r"d:\yep") - # UNC can be relative to the same UNC - assert check_paths_relative(d, r"\\host\share") - - def test_convert_string_to_bytes(): - # Test SI prefixes - assert convert_string_to_bytes("1k") == 1000 - assert convert_string_to_bytes("1M") == 1000000 - assert 
convert_string_to_bytes("1G") == 1000000000 - assert convert_string_to_bytes("1T") == 1000000000000 - assert convert_string_to_bytes("1P") == 1000000000000000 - # Test IEC prefixes - assert convert_string_to_bytes("1Ki") == 1024 - assert convert_string_to_bytes("1Mi") == 1048576 - assert convert_string_to_bytes("1Gi") == 1073741824 - assert convert_string_to_bytes("1Ti") == 1099511627776 - assert convert_string_to_bytes("1Pi") == 1125899906842624 - # Test mixed values - assert convert_string_to_bytes("1.5k") == 1500 - assert convert_string_to_bytes("2.5M") == 2500000 - assert convert_string_to_bytes("0.5G") == 500000000 - assert convert_string_to_bytes("1.2T") == 1200000000000 - assert convert_string_to_bytes("1.5Pi") == 1688849860263936 - # Test zero values - assert convert_string_to_bytes("0k") == 0 - assert convert_string_to_bytes("0Ki") == 0 - # Test invalid inputs (should raise assertion error) - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Z") - assert str(e.value) == "Unknown suffix: Z" - - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Xi") - assert str(e.value) == "Unknown suffix: Xi" +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_universal_naming_convention(): + d = { + "kwargs": { + "path": r"\\host\share\yep\sub\path1", + } + } + # UNC cannot be relative to d: drive + assert not check_paths_relative(d, r"d:\yep") + + # UNC can be relative to the same UNC + assert check_paths_relative(d, r"\\host\share") + + +def test_make_paths_relative(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + # Create the objects in the path + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + extractor_dict = { + "kwargs": { + "path": str(path_1), # Note this is different in windows and posix + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": str(path_2)}, + }, + } + } + modified_extractor_dict = make_paths_relative(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"] == "sub/path1" + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_make_paths_absolute(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + + extractor_dict = { + "kwargs": { + "path": "sub/path1", + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "sub/path2"}, + }, + } + } + + modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_convert_string_to_bytes(): + # Test SI prefixes + assert convert_string_to_bytes("1k") == 1000 + assert convert_string_to_bytes("1M") == 1000000 + assert 
convert_string_to_bytes("1G") == 1000000000 + assert convert_string_to_bytes("1T") == 1000000000000 + assert convert_string_to_bytes("1P") == 1000000000000000 + # Test IEC prefixes + assert convert_string_to_bytes("1Ki") == 1024 + assert convert_string_to_bytes("1Mi") == 1048576 + assert convert_string_to_bytes("1Gi") == 1073741824 + assert convert_string_to_bytes("1Ti") == 1099511627776 + assert convert_string_to_bytes("1Pi") == 1125899906842624 + # Test mixed values + assert convert_string_to_bytes("1.5k") == 1500 + assert convert_string_to_bytes("2.5M") == 2500000 + assert convert_string_to_bytes("0.5G") == 500000000 + assert convert_string_to_bytes("1.2T") == 1200000000000 + assert convert_string_to_bytes("1.5Pi") == 1688849860263936 + # Test zero values + assert convert_string_to_bytes("0k") == 0 + assert convert_string_to_bytes("0Ki") == 0 + # Test invalid inputs (should raise assertion error) + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Z") + assert str(e.value) == "Unknown suffix: Z" + + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Xi") + assert str(e.value) == "Unknown suffix: Xi" def test_normal_pdf() -> None: From b3b85b2fe5670217d80c4adec1a751d1e1d5d024 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 17:21:45 -0600 Subject: [PATCH 21/44] naming --- src/spikeinterface/core/core_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 9e90b56c8d..d5480d6f00 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -228,7 +228,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): yield from _extractor_dict_iterator(extractor_dict) -def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value): +def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value): """ In place modification of a value in a nested dictionary given its access path. 
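
Editor's note: a small sketch of the access-path semantics under the new name. The dictionary below is a toy stand-in; the function modifies it in place and returns nothing.

# In-place update through an access path, mirroring the implementation above.
from spikeinterface.core.core_tools import set_value_in_extractor_dict

d = {"kwargs": {"recording": {"kwargs": {"path": "/old/location.bin"}}}}
set_value_in_extractor_dict(
    extractor_dict=d,
    access_path=("kwargs", "recording", "kwargs", "path"),
    new_value="/new/location.bin",
)
assert d["kwargs"]["recording"]["kwargs"]["path"] == "/new/location.bin"
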
@@ -416,7 +416,7 @@ def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: output_dict = deepcopy(input_dict) for element in path_elements_in_dict: new_value = _relative_to(element.value, relative_folder) - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, @@ -453,7 +453,7 @@ def make_paths_absolute(input_dict, base_folder): absolute_path = (base_folder / element.value).resolve() if Path(absolute_path).exists(): new_value = absolute_path.as_posix() # Not so sure about this, Sam - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, From d794c8220e9e2ed2431636e53aee9b7b8d6b998b Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Thu, 27 Jun 2024 00:39:58 -0600 Subject: [PATCH 22/44] windows test remove inner conditional --- .../core/tests/test_core_tools.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 043e0cabf3..ed13bd46fd 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,25 +54,24 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - if platform.system() == "Windows": - # test for windows Path - d = { - "kwargs": { - "path": r"c:\yep\sub\path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": r"c:\yep\sub\path2"}, - }, - } + + d = { + "kwargs": { + "path": r"c:\yep\sub\path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": r"c:\yep\sub\path2"}, + }, } + } - # same drive - assert check_paths_relative(d, r"c:\yep") - # not the same drive - assert not check_paths_relative(d, r"d:\yep") + # same drive + assert check_paths_relative(d, r"c:\yep") + # not the same drive + assert not check_paths_relative(d, r"d:\yep") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") @@ -139,8 +138,8 @@ def test_make_paths_absolute(tmp_path): } modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) - assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) - assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path.as_posix())) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path.as_posix())) assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" From c1e4eee519c289899f2650d98e6210d631ae42f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:41:00 +0000 Subject: [PATCH 23/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_core_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index ed13bd46fd..724517577c 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ 
b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,7 +54,7 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - + d = { "kwargs": { "path": r"c:\yep\sub\path1", From efede134e52a0a01e1665cffb5543a696673b525 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 27 Jun 2024 08:50:11 +0100 Subject: [PATCH 24/44] use_names_as_ids update --- src/spikeinterface/extractors/neoextractors/blackrock.py | 5 +++-- src/spikeinterface/extractors/neoextractors/intan.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 0015fd9f67..ab3710e05e 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -26,8 +26,9 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. - use_names_as_ids : bool or None, default: None - If True, use channel names as IDs. If None, use default IDs. + use_names_as_ids : bool, default: False + If False, use default IDs inherited from Neo. If True, use channel names as IDs. + """ mode = "file" diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 50fda79123..43439b80c9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -28,7 +28,9 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. use_names_as_ids : bool, default: False - If True, use channel names as IDs. If False, use default IDs inherited from neo. + If False, use default IDs inherited from Neo. If True, use channel names as IDs. + + """ mode = "file" From 1aa036885b3fefc3bf8440ee2a7cd71295badf0f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:04:03 +0200 Subject: [PATCH 25/44] Move drift_raster_map to motion, typing, docs, and tests --- src/spikeinterface/widgets/driftmap.py | 143 --------- src/spikeinterface/widgets/motion.py | 284 +++++++++++++++--- .../widgets/tests/test_widgets.py | 48 +-- src/spikeinterface/widgets/widget_list.py | 7 +- 4 files changed, 266 insertions(+), 216 deletions(-) delete mode 100644 src/spikeinterface/widgets/driftmap.py diff --git a/src/spikeinterface/widgets/driftmap.py b/src/spikeinterface/widgets/driftmap.py deleted file mode 100644 index 60e8df2972..0000000000 --- a/src/spikeinterface/widgets/driftmap.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from .base import BaseWidget, to_attr - - -class DriftMapWidget(BaseWidget): - """ - Plot the a drift map from a motion info dictionary. - - Parameters - ---------- - peaks : np.array - The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index") - peak_locations : np.array - The peak locations, with dtype ("x", "y") or ("x", "y", "z") - direction : "x" or "y", default: "y" - The direction to display - segment_index : int, default: None - The segment index to display. 
- recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times) - segment_index : int, default: 0 - The segment index to display. - sampling_frequency : float, default: None - The sampling frequency (needed if recording is None) - depth_lim : tuple or None, default: None - The min and max depth to display, if None (min and max of the recording) - color_amplitude : bool, default: True - If True, the color of the scatter points is the amplitude of the peaks - scatter_decimate : int, default: None - If > 1, the scatter points are decimated - cmap : str, default: "inferno" - The colormap to use for the amplitude - clim : tuple or None, default: None - The min and max amplitude to display, if None (min and max of the amplitudes) - alpha : float, default: 1 - The alpha of the scatter points - """ - - def __init__( - self, - peaks, - peak_locations, - direction="y", - recording=None, - sampling_frequency=None, - segment_index=None, - depth_lim=None, - color_amplitude=True, - scatter_decimate=None, - cmap="inferno", - clim=None, - alpha=1, - backend=None, - **backend_kwargs, - ): - if segment_index is None: - assert ( - len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there is only one segment in the peaks array" - assert recording or sampling_frequency, "recording or sampling_frequency must be specified" - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = recording.get_times(segment_index=segment_index) - else: - times = None - - plot_data = dict( - peaks=peaks, - peak_locations=peak_locations, - direction=direction, - times=times, - sampling_frequency=sampling_frequency, - segment_index=segment_index, - depth_lim=depth_lim, - color_amplitude=color_amplitude, - scatter_decimate=scatter_decimate, - cmap=cmap, - clim=clim, - alpha=alpha, - recording=recording, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from matplotlib.colors import Normalize - - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - - dp = to_attr(data_plot) - - assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - fig = self.figure - - if dp.times is None: - # temporal_bins_plot = dp.temporal_bins - x = dp.peaks["sample_index"] / dp.sampling_frequency - else: - # use real times and adjust temporal bins with t_start - # temporal_bins_plot = dp.temporal_bins + dp.times[0] - x = dp.times[dp.peaks["sample_index"]] - - y = dp.peak_locations[dp.direction] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = dp.peaks["amplitude"] - amps_abs = np.abs(amps) - q_95 = np.quantile(amps_abs, 0.95) - if dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.colormaps[dp.cmap] - if dp.clim is None: - amps = amps_abs - amps /= q_95 - c = cmap(amps) - else: - norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) - c = cmap(norm_function(amps)) - color_kwargs = dict( - color=None, - c=c, - alpha=dp.alpha, - ) - else: - color_kwargs = dict(color="k", c=None, alpha=dp.alpha) - - self.ax.scatter(x, y, s=1, 
**color_kwargs)
-        if dp.depth_lim is not None:
-            self.ax.set_ylim(*dp.depth_lim)
-        self.ax.set_title("Peak depth")
-        self.ax.set_xlabel("Times [s]")
-        self.ax.set_ylabel("Depth [$\\mu$m]")
diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 7d733523df..ee1599822f 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -3,31 +3,32 @@
 import numpy as np
 
 from .base import BaseWidget, to_attr
-from .driftmap import DriftMapWidget
+
+from spikeinterface.core import BaseRecording, SortingAnalyzer
+from spikeinterface.sortingcomponents.motion_utils import Motion
 
 
 class MotionWidget(BaseWidget):
     """
-    Plot the Motion object
+    Plot the Motion object.
 
     Parameters
     ----------
     motion : Motion
-        The motion object
-    segment_index : None | int
-        If Motion is multi segment, the must be not None
-    mode : "auto" | "line" | "map"
-        How to plot map or lines.
-        "auto" make it automatic if the number of depth is too high.
+        The motion object.
+    segment_index : int | None, default: None
+        If Motion is multi segment, it must not be None.
+    mode : "auto" | "line" | "map", default: "line"
+        How to plot map or lines. "auto" makes it automatic if the number of motion depths is too high.
     """
 
     def __init__(
         self,
-        motion,
-        segment_index=None,
-        mode="line",
-        motion_lim=None,
-        backend=None,
+        motion: Motion,
+        segment_index: int | None = None,
+        mode: str = "line",
+        motion_lim: float | None = None,
+        backend: str | None = None,
         **backend_kwargs,
     ):
         if isinstance(motion, dict):
@@ -51,19 +52,15 @@ def __init__(
         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
 
     def plot_matplotlib(self, data_plot, **backend_kwargs):
-        import matplotlib.pyplot as plt
         from .utils_matplotlib import make_mpl_figure
-        from matplotlib.colors import Normalize
 
         dp = to_attr(data_plot)
 
-        motion = data_plot["motion"]
-        segment_index = data_plot["segment_index"]
-
         assert backend_kwargs["axes"] is None
 
         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
 
+        motion = dp.motion
         displacement = motion.displacement[dp.segment_index]
         temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
         depth = motion.spatial_bins_um
@@ -97,55 +94,241 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         ax.set_ylabel("Depth [um]")
 
 
+class DriftRasterMapWidget(BaseWidget):
+    """
+    Plot the drift raster map from peaks or a SortingAnalyzer.
+
+    Parameters
+    ----------
+    peaks : np.array | None, default: None
+        The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index"),
+        as returned by the `detect_peaks` or `correct_motion` functions.
+    peak_locations : np.array | None, default: None
+        The peak locations, with dtype ("x", "y") or ("x", "y", "z"), as returned by the
+        `localize_peaks` or `correct_motion` functions.
+    sorting_analyzer : SortingAnalyzer | None, default: None
+        The sorting analyzer object. To use this function, the `SortingAnalyzer` must have the
+        "spike_locations" extension computed.
+    direction : "x" or "y", default: "y"
+        The direction to display.
+    segment_index : int, default: None
+        The segment index to display.
+    recording : RecordingExtractor | None, default: None
+        The recording extractor object (only used to get "real" times).
+    sampling_frequency : float, default: None
+        The sampling frequency (needed if recording is None).
+ depth_lim : tuple or None, default: None + The min and max depth to display, if None (min and max of the recording). + scatter_decimate : int, default: None + If > 1, the scatter points are decimated. + color_amplitude : bool, default: True + If True, the color of the scatter points is the amplitude of the peaks. + cmap : str, default: "inferno" + The colormap to use for the amplitude. + color : str, default: "Gray" + The color of the scatter points if color_amplitude is False. + clim : tuple or None, default: None + The min and max amplitude to display, if None (min and max of the amplitudes). + alpha : float, default: 1 + The alpha of the scatter points. + """ + + def __init__( + self, + peaks: np.array | None = None, + peak_locations: np.array | None = None, + sorting_analyzer: SortingAnalyzer | None = None, + direction: str = "y", + recording: BaseRecording | None = None, + sampling_frequency: float | None = None, + segment_index: int | None = None, + depth_lim: tuple[float, float] | None = None, + color_amplitude: bool = True, + scatter_decimate: int | None = None, + cmap: str = "inferno", + color: str = "Gray", + clim: tuple[float, float] | None = None, + alpha: float = 1, + backend: str | None = None, + **backend_kwargs, + ): + assert peaks is not None or sorting_analyzer is not None + if peaks is not None: + assert peak_locations is not None + if recording is None: + assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency" + else: + sampling_frequency = recording.sampling_frequency + peak_amplitudes = peaks["amplitude"] + if sorting_analyzer is not None: + if sorting_analyzer.has_recording(): + recording = sorting_analyzer.recording + else: + recording = None + sampling_frequency = sorting_analyzer.sampling_frequency + peaks = sorting_analyzer.sorting.to_spike_vector() + assert sorting_analyzer.has_extension( + "spike_locations" + ), "The sorting analyzer must have the 'spike_locations' extension to use this function" + peak_locations = sorting_analyzer.get_extension("spike_locations").get_data() + if color_amplitude: + assert sorting_analyzer.has_extension("spike_amplitudes"), ( + "The sorting analyzer must have the 'spike_amplitudes' extension to use color_amplitude=True. " + "You can compute it or set color_amplitude=False." 
+                )
+            if sorting_analyzer.has_extension("spike_amplitudes"):
+                peak_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data()
+            else:
+                peak_amplitudes = None
+        times = recording.get_times(segment_index=segment_index) if recording is not None else None
+
+        if segment_index is None:
+            assert (
+                len(np.unique(peaks["segment_index"])) == 1
+            ), "segment_index must be specified if there is only one segment in the peaks array"
+            segment_index = 0
+        else:
+            peak_mask = peaks["segment_index"] == segment_index
+            peaks = peaks[peak_mask]
+            peak_locations = peak_locations[peak_mask]
+            if peak_amplitudes is not None:
+                peak_amplitudes = peak_amplitudes[peak_mask]
+
+        if recording is not None:
+            sampling_frequency = recording.sampling_frequency
+            times = recording.get_times(segment_index=segment_index)
+        else:
+            times = None
+
+        plot_data = dict(
+            peaks=peaks,
+            peak_locations=peak_locations,
+            peak_amplitudes=peak_amplitudes,
+            direction=direction,
+            times=times,
+            sampling_frequency=sampling_frequency,
+            segment_index=segment_index,
+            depth_lim=depth_lim,
+            color_amplitude=color_amplitude,
+            color=color,
+            scatter_decimate=scatter_decimate,
+            cmap=cmap,
+            clim=clim,
+            alpha=alpha,
+            recording=recording,
+        )
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from matplotlib.colors import Normalize
+        from .utils_matplotlib import make_mpl_figure
+
+        from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks
+
+        dp = to_attr(data_plot)
+
+        assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget"
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+        fig = self.figure
+
+        if dp.times is None:
+            x = dp.peaks["sample_index"] / dp.sampling_frequency
+        else:
+            x = dp.times[dp.peaks["sample_index"]]
+
+        y = dp.peak_locations[dp.direction]
+        if dp.scatter_decimate is not None:
+            x = x[:: dp.scatter_decimate]
+            y = y[:: dp.scatter_decimate]
+            y2 = y2[:: dp.scatter_decimate]
+
+        if dp.color_amplitude:
+            amps = dp.peak_amplitudes
+            amps_abs = np.abs(amps)
+            q_95 = np.quantile(amps_abs, 0.95)
+            if dp.scatter_decimate is not None:
+                amps = amps[:: dp.scatter_decimate]
+                amps_abs = amps_abs[:: dp.scatter_decimate]
+            cmap = plt.colormaps[dp.cmap]
+            if dp.clim is None:
+                amps = amps_abs
+                amps /= q_95
+                c = cmap(amps)
+            else:
+                norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True)
+                c = cmap(norm_function(amps))
+            color_kwargs = dict(
+                color=None,
+                c=c,
+                alpha=dp.alpha,
+            )
+        else:
+            color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha)
+
+        self.ax.scatter(x, y, s=1, **color_kwargs)
+        if dp.depth_lim is not None:
+            self.ax.set_ylim(*dp.depth_lim)
+        self.ax.set_title("Peak depth")
+        self.ax.set_xlabel("Times [s]")
+        self.ax.set_ylabel("Depth [$\\mu$m]")
+
+
 class MotionInfoWidget(BaseWidget):
     """
-    Plot motion information from the motion_info dict returned by correct_motion().
-    This plot:
-      * the motion iself
-      * the peak depth vs time before correction
-      * the peak depth vs time after correction
+    Plot motion information from the motion_info dictionary returned by the `correct_motion()` function.
+    This widget plots::
+        * the motion itself
+        * the drift raster map (peak depth vs time) before correction
+        * the drift raster map (peak depth vs time) after correction
 
     Parameters
     ----------
     motion_info : dict
-        The motion info returned by correct_motion() or loaded back with load_motion_info()
+        The motion info returned by correct_motion() or loaded back with load_motion_info().
     segment_index : int, default: None
         The segment index to display.
     recording : RecordingExtractor, default: None
-        The recording extractor object (only used to get "real" times)
+        The recording extractor object (only used to get "real" times).
     segment_index : int, default: 0
         The segment index to display.
     sampling_frequency : float, default: None
-        The sampling frequency (needed if recording is None)
+        The sampling frequency (needed if recording is None).
     depth_lim : tuple or None, default: None
-        The min and max depth to display, if None (min and max of the recording)
+        The min and max depth to display, if None (min and max of the recording).
     motion_lim : tuple or None, default: None
-        The min and max motion to display, if None (min and max of the motion)
-    color_amplitude : bool, default: False
-        If True, the color of the scatter points is the amplitude of the peaks
+        The min and max motion to display, if None (min and max of the motion).
     scatter_decimate : int, default: None
-        If > 1, the scatter points are decimated
+        If > 1, the scatter points are decimated.
+    color_amplitude : bool, default: False
+        If True, the color of the scatter points is the amplitude of the peaks.
     amplitude_cmap : str, default: "inferno"
-        The colormap to use for the amplitude
+        The colormap to use for the amplitude.
+    amplitude_color : str, default: "Gray"
+        The color of the scatter points if color_amplitude is False.
     amplitude_clim : tuple or None, default: None
-        The min and max amplitude to display, if None (min and max of the amplitudes)
+        The min and max amplitude to display, if None (min and max of the amplitudes).
     amplitude_alpha : float, default: 1
-        The alpha of the scatter points
+        The alpha of the scatter points.
""" def __init__( self, - motion_info, - segment_index=None, - recording=None, - depth_lim=None, - motion_lim=None, - color_amplitude=False, - scatter_decimate=None, - amplitude_cmap="inferno", - amplitude_clim=None, - amplitude_alpha=1, - backend=None, + motion_info: dict, + segment_index: int | None = None, + recording: BaseRecording | None = None, + depth_lim: tuple[float, float] | None = None, + motion_lim: tuple[float, float] | None = None, + color_amplitude: bool = False, + scatter_decimate: int | None = None, + amplitude_cmap: str = "inferno", + amplitude_color: str = "Gray", + amplitude_clim: tuple[float, float] | None = None, + amplitude_alpha: float = 1, + backend: str | None = None, **backend_kwargs, ): @@ -169,6 +352,7 @@ def __init__( color_amplitude=color_amplitude, scatter_decimate=scatter_decimate, amplitude_cmap=amplitude_cmap, + amplitude_color=amplitude_color, amplitude_clim=amplitude_clim, amplitude_alpha=amplitude_alpha, recording=recording, @@ -178,9 +362,7 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from matplotlib.colors import Normalize from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks @@ -229,15 +411,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording=dp.recording, segment_index=dp.segment_index, depth_lim=dp.depth_lim, - color_amplitude=dp.color_amplitude, scatter_decimate=dp.scatter_decimate, + color_amplitude=dp.color_amplitude, + color=dp.amplitude_color, cmap=dp.amplitude_cmap, clim=dp.amplitude_clim, alpha=dp.amplitude_alpha, backend="matplotlib", ) - drift_map = DriftMapWidget( + # with immediate_plot=True the widgets are plotted immediately + _ = DriftRasterMapWidget( dp.peaks, dp.peak_locations, ax=ax0, @@ -245,7 +429,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): **commpon_drift_map_kwargs, ) - drift_map_corrected = DriftMapWidget( + _ = DriftRasterMapWidget( dp.peaks, corrected_location, ax=ax1, diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index e841a1c93b..0eef8539cc 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -22,7 +22,7 @@ import spikeinterface.widgets as sw import spikeinterface.comparison as sc -from spikeinterface.preprocessing import scale +from spikeinterface.preprocessing import scale, correct_motion ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -56,6 +56,9 @@ def setUpClass(cls): cls.recording = recording cls.sorting = sorting + # estimate motion for motion widgets + _, cls.motion_info = correct_motion(recording, preset="kilosort_like", output_motion_info=True) + cls.num_units = len(cls.sorting.get_unit_ids()) extensions_to_compute = dict( @@ -581,9 +584,7 @@ def test_plot_multicomparison(self): sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) def test_plot_motion(self): - from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - - motion = make_fake_motion() + motion = self.motion_info["motion"] possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: @@ -591,22 +592,31 @@ def test_plot_motion(self): sw.plot_motion(motion, backend=backend, mode="line") sw.plot_motion(motion, backend=backend, mode="map") - def test_plot_motion_info(self): - from 
spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - - motion = make_fake_motion() - rng = np.random.default_rng(seed=2205) - peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) - peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) - - motion_info = dict( - motion=motion, - parameters=dict(sampling_frequency=30000.0), - run_times=dict(), - peaks=self.peaks, - peak_locations=peak_locations, - ) + def test_drift_raster_map(self): + peaks = self.motion_info["peaks"] + recording = self.recording + peak_locations = self.motion_info["peak_locations"] + analyzer = self.sorting_analyzer_sparse + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + # with recoridng + sw.plot_drift_raster_map( + peaks=peaks, peak_locations=peak_locations, recording=recording, color_amplitude=True + ) + # without recording + sw.plot_drift_raster_map( + peaks=peaks, + peak_locations=peak_locations, + sampling_frequency=recording.sampling_frequency, + color_amplitude=False, + ) + # with analyzer + sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True) + + def test_plot_motion_info(self): + motion_info = self.motion_info possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8d4accaa7e..8163271ec4 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,9 +9,8 @@ from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget -from .driftmap import DriftMapWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget, MotionInfoWidget +from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget from .peaks_on_probe import PeaksOnProbeWidget @@ -45,7 +44,7 @@ ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, - DriftMapWidget, + DriftRasterMapWidget, ISIDistributionWidget, MotionWidget, MotionInfoWidget, @@ -120,7 +119,7 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget -plot_drift_map = DriftMapWidget +plot_drift_raster_map = DriftRasterMapWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget plot_motion_info = MotionInfoWidget From 30b60e7eab49bfa47696593e8f7f3506113cda53 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:05:47 +0200 Subject: [PATCH 26/44] Add explanation on what drift rastermap is --- src/spikeinterface/widgets/motion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index ee1599822f..66ef2a3f01 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -97,6 +97,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class DriftRasterMapWidget(BaseWidget): """ Plot the drift raster map from peaks or a SortingAnalyzer. 
+ The drift raster map is a scatter plot of the estimated peak depth vs time and it is + useful to visualize the drift over the course of the recording. Parameters ---------- From 31064ec453f65cac23baa2379991b0996492618b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:06:39 +0200 Subject: [PATCH 27/44] Add explanation on 'y' direction --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 66ef2a3f01..31edbf2f4d 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -112,7 +112,7 @@ class DriftRasterMapWidget(BaseWidget): The sorting analyzer object. To use this function, the `SortingAnalyzer` must have the "spike_locations" extension computed. direction : "x" or "y", default: "y" - The direction to display. + The direction to display. "y" is the depth direction. segment_index : int, default: None The segment index to display. recording : RecordingExtractor | None, default: None From cc550b9622bee8bf11a11b585ee9ff02cb829423 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:07:35 +0200 Subject: [PATCH 28/44] Fix segment index error --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 31edbf2f4d..31a938829d 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -189,7 +189,7 @@ def __init__( if segment_index is None: assert ( len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there is only one segment in the peaks array" + ), "segment_index must be specified if there are multiple segments" segment_index = 0 else: peak_mask = peaks["segment_index"] == segment_index From 80ba2e512f568a2b96ea3e38095bc19f9a987480 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:12:03 +0200 Subject: [PATCH 29/44] Review suggestions and test with scatter_decimate --- src/spikeinterface/widgets/motion.py | 16 +++++++--------- src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 31a938829d..895a8733c7 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -232,21 +232,19 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["axes"] is None, "axes argument is not allowed in DriftRasterMapWidget. Use ax instead." 
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
-        fig = self.figure
 
         if dp.times is None:
-            x = dp.peaks["sample_index"] / dp.sampling_frequency
+            peak_times = dp.peaks["sample_index"] / dp.sampling_frequency
         else:
-            x = dp.times[dp.peaks["sample_index"]]
+            peak_times = dp.times[dp.peaks["sample_index"]]
 
-        y = dp.peak_locations[dp.direction]
+        peak_locs = dp.peak_locations[dp.direction]
         if dp.scatter_decimate is not None:
-            x = x[:: dp.scatter_decimate]
-            y = y[:: dp.scatter_decimate]
-            y2 = y2[:: dp.scatter_decimate]
+            peak_times = peak_times[:: dp.scatter_decimate]
+            peak_locs = peak_locs[:: dp.scatter_decimate]
 
         if dp.color_amplitude:
             amps = dp.peak_amplitudes
@@ -271,7 +269,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         else:
             color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha)
 
-        self.ax.scatter(x, y, s=1, **color_kwargs)
+        self.ax.scatter(peak_times, peak_locs, s=1, **color_kwargs)
         if dp.depth_lim is not None:
             self.ax.set_ylim(*dp.depth_lim)
         self.ax.set_title("Peak depth")
diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 0eef8539cc..7887ecda66 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -613,7 +613,7 @@ def test_drift_raster_map(self):
                     color_amplitude=False,
                 )
                 # with analyzer
-                sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True)
+                sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True, scatter_decimate=2)
 
     def test_plot_motion_info(self):
         motion_info = self.motion_info
From 30b60e7eab49bfa47696593e8f7f3506113cda53 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 27 Jun 2024 12:22:09 +0200
Subject: [PATCH 30/44] Mark failing sorter test on Windows*Python3.12 as xfail

---
 src/spikeinterface/sorters/tests/test_runsorter.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py
index 470bdc3602..6bd73c5691 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter.py
@@ -1,7 +1,9 @@
 import os
+import platform
 import pytest
 from pathlib import Path
 import shutil
+from packaging.version import parse
 
 from spikeinterface import generate_ground_truth_recording
 from spikeinterface.sorters import run_sorter
@@ -19,6 +21,10 @@ def generate_recording():
     return _generate_recording()
 
 
+@pytest.mark.xfail(
+    platform.system() == "Windows" and parse(platform.python_version()) > parse("3.12"),
+    reason="3rd party threadpoolctl issue: OSError('GetModuleFileNameEx failed')",
+)
 def test_run_sorter_local(generate_recording, create_cache_folder):
     recording = generate_recording
     cache_folder = create_cache_folder
From d1d65f6ca6338ac2dd8d6f9c99ee657f0db76d21 Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 11:58:23 +0100
Subject: [PATCH 31/44] estimate_sparsity arg ordering

---
 src/spikeinterface/core/sortinganalyzer.py         | 2 +-
 src/spikeinterface/core/sparsity.py                | 6 +++---
 src/spikeinterface/core/tests/test_sparsity.py     | 4 ++--
 .../postprocessing/tests/common_extension_tests.py | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 53e060262b..62b7f9e7c0 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -127,7 +127,7 @@ def 
create_sorting_analyzer(
             recording.channel_ids, sparsity.channel_ids
         ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
     elif sparse:
-        sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
+        sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
     else:
         sparsity = None
diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index cefd7bd950..1cd7822f99 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -539,8 +539,8 @@ def compute_sparsity(


 def estimate_sparsity(
-    recording: BaseRecording,
     sorting: BaseSorting,
+    recording: BaseRecording,
     num_spikes_for_sparsity: int = 100,
     ms_before: float = 1.0,
     ms_after: float = 2.5,
@@ -563,10 +563,10 @@

     Parameters
     ----------
-    recording: BaseRecording
-        The recording
     sorting: BaseSorting
         The sorting
+    recording: BaseRecording
+        The recording
     num_spikes_for_sparsity: int, default: 100
         How many spikes per unit to compute the sparsity
     ms_before: float, default: 1.0
diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py
index 98d033d8ea..a192d90502 100644
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -166,8 +166,8 @@ def test_estimate_sparsity():

     # small radius should give a very sparse = one channel per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
@@ -182,8 +182,8 @@

     # best_channel : the mask should have exactly 3 channels per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index bf462a9466..8c46fa5e24 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -79,7 +79,7 @@ class AnalyzerExtensionCommonTestSuite:
     def setUpClass(cls):
         cls.recording, cls.sorting = get_dataset()
         # sparsity is computed once for all cases to save processing time and force a small radius
-        cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20)
+        cls.sparsity = estimate_sparsity(cls.sorting, cls.recording, method="radius", radius_um=20)

     @property
     def extension_name(self):

From 02ae32a857c9ce59a54deffcc1465a3d975342aa Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 27 Jun 2024 14:22:18 +0200
Subject: [PATCH 32/44] Update src/spikeinterface/widgets/motion.py

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/widgets/motion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 895a8733c7..5f0e02fdab 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -280,7 +280,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
 class MotionInfoWidget(BaseWidget):
     """
     Plot motion information from the motion_info dictionary returned by the `correct_motion()` function.
-    This widget plots::
+    This widget plots:

     * the motion itself
     * the drift raster map (peak depth vs time) before correction
     * the drift raster map (peak depth vs time) after correction

From c111cfcacb1c80b4166320c3f3753a2a7d629f69 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 27 Jun 2024 14:22:28 +0200
Subject: [PATCH 33/44] Update src/spikeinterface/widgets/tests/test_widgets.py

Co-authored-by: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com>
---
 src/spikeinterface/widgets/tests/test_widgets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 7887ecda66..012b1ac07c 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -601,7 +601,7 @@ def test_drift_raster_map(self):
         possible_backends = list(sw.MotionWidget.get_possible_backends())
         for backend in possible_backends:
             if backend not in self.skip_backends:
-                # with recoridng
+                # with recording
                 sw.plot_drift_raster_map(
                     peaks=peaks, peak_locations=peak_locations, recording=recording, color_amplitude=True
                 )

From 12d823bb0dc7e1536486508c473f0ce5562e395a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 27 Jun 2024 14:37:10 +0200
Subject: [PATCH 34/44] Better docs for plot mode (line, map, auto)

---
 src/spikeinterface/widgets/motion.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 895a8733c7..766938299a 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -19,7 +19,10 @@ class MotionWidget(BaseWidget):
     segment_index : int | None, default: None
         If Motion is multi-segment, this must not be None.
     mode : "auto" | "line" | "map", default: "line"
-        How to plot map or lines. "auto" makes it automatic if the number of motion depths is too high.
+        How to plot the motion.
+        "line" plots estimated motion at different depths as lines.
+        "map" plots estimated motion at different depths as a heatmap.
+        "auto" makes it automatic depending on the number of motion depths.
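To make the three modes concrete, a small sketch (an assumption-laden example, not part of the diff: it takes `motion` to be the Motion object returned by `correct_motion()`, and assumes MotionWidget is exposed as `plot_motion`, following the `plot_*` naming used for the other widgets in this series):

    import spikeinterface.widgets as sw

    sw.plot_motion(motion, mode="line")  # one line per motion depth
    sw.plot_motion(motion, mode="map")   # depth-by-time heatmap of the motion
    sw.plot_motion(motion, mode="auto")  # picks "line" or "map" from the depth count
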
""" def __init__( From a3deed8211f9b20e3acbe41f9b7297e285ba68ed Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 14:39:07 +0200 Subject: [PATCH 35/44] Remove duplicated line --- src/spikeinterface/widgets/motion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index bf9010c144..0b79350a62 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -187,7 +187,6 @@ def __init__( peak_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() else: peak_amplitudes = None - times = recording.get_times(segment_index=segment_index) if recording is not None else None if segment_index is None: assert ( From 2cc719986e5d6fceb9ea828206d7cf1d9a3fef9a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 08:11:55 -0600 Subject: [PATCH 36/44] @alejo91 suggestion --- src/spikeinterface/core/core_tools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index d5480d6f00..066ab58d8c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -184,10 +184,10 @@ def is_dict_extractor(d: dict) -> bool: return is_extractor -recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) +extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"]) -def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. @@ -204,7 +204,7 @@ def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el Yields ------ - recording_dict_element + extractor_dict_element Named tuple containing the value, the name, and the access_path to the value in the dictionary. 
""" @@ -219,7 +219,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): v, access_path + (i,), name=name ) # Propagate name of list to children else: - yield recording_dict_element( + yield extractor_dict_element( value=dict_list_or_value, name=name, access_path=access_path, @@ -320,7 +320,7 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -# This is the current definition that an element in a recording_dict is a path +# This is the current definition that an element in a extractor_dict is a path # This is shared across a couple of definition so it is here for DNRY element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) From 61060781eef87597461241aec077aac27baff69b Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:15:14 +0100 Subject: [PATCH 37/44] SpikeRetriever arg switch --- src/spikeinterface/core/node_pipeline.py | 16 +-- .../core/tests/test_node_pipeline.py | 4 +- .../tests/test_train_manual_curation.py | 120 ++++++++++++++++++ .../postprocessing/amplitude_scalings.py | 2 +- .../postprocessing/spike_amplitudes.py | 2 +- .../postprocessing/spike_locations.py | 2 +- 6 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1c0107d235..0722ede23f 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource): * compute_spike_amplitudes() * compute_principal_components() + sorting : BaseSorting + The sorting object. recording : BaseRecording The recording object. - sorting: BaseSorting - The sorting object. - channel_from_template: bool, default: True + channel_from_template : bool, default: True If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. If False, the max channel is computed for each spike given a radius around the template max channel. - extremum_channel_inds: dict of int | None, default: None + extremum_channel_inds : dict of int | None, default: None The extremum channel index dict given from template. - radius_um: float, default: 50 + radius_um : float, default: 50 The radius to find the real max channel. Used only when channel_from_template=False - peak_sign: "neg" | "pos", default: "neg" + peak_sign : "neg" | "pos", default: "neg" Peak sign to find the max channel. 
        Used only when channel_from_template=False
-    include_spikes_in_margin: bool, default False
+    include_spikes_in_margin : bool, default False
         If True, then spikes in the margin are added and an extra field is added to the dtype
     """

     def __init__(
         self,
-        recording,
         sorting,
+        recording,
         channel_from_template=True,
         extremum_channel_inds=None,
         radius_um=50,
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index 03acc9fed1..8d788acbad 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
     spike_retriever_T = SpikeRetriever(
-        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
+        sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
     )
     # channel index is per spike
     spike_retriever_S = SpikeRetriever(
-        recording,
         sorting,
+        recording,
         channel_from_template=False,
         extremum_channel_inds=extremum_channel_inds,
         radius_um=50,
diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py
new file mode 100644
index 0000000000..f0f9ff4d75
--- /dev/null
+++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -0,0 +1,120 @@
+import pytest
+import pandas as pd
+import os
+import shutil
+
+from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model
+
+# Sample data for testing
+data = {
+    'num_spikes': [1, 2, 3, 4, 5, 6],
+    'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4],
+    'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06],
+    'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
+    'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'firing_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'num_positive_peaks': [1, 2, 3, 4, 5, 6],
+    'num_negative_peaks': [1, 2, 3, 4, 5, 6],
+    'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+    'is_noise': [0, 1, 0, 1, 0, 1],
+    'is_sua': [1, 0, 1, 0, 1, 0],
+    'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad']
+}
+
+df = pd.DataFrame(data)
+
+# Test initialization
+def test_initialization():
+    trainer = 
CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + assert trainer.output_folder == '/tmp' + assert trainer.curator_column == 'num_spikes' + assert trainer.imputation_strategies is not None + assert trainer.scaling_techniques is not None + +# Test load_data_file +def test_load_data_file(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + df.to_csv('/tmp/test.csv', index=False) + trainer.load_data_file('/tmp/test.csv') + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Test process_test_data_for_classification +def test_process_test_data_for_classification(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + trainer.testing_metrics = {0: df} + trainer.process_test_data_for_classification() + assert trainer.noise_test is not None + assert trainer.sua_mua_test is not None + +# Test apply_scaling_imputation +def test_apply_scaling_imputation(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + y_train = df['is_noise'] + y_val = df['is_noise'] + result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) + assert result is not None + +# Test get_classifier_search_space +def test_get_classifier_search_space(): + from sklearn.linear_model import LogisticRegression + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + model, param_space = trainer.get_classifier_search_space(LogisticRegression) + assert model is not None + assert param_space is not None + +# Test Objective Enum +def test_objective_enum(): + assert Objective.Noise == Objective(1) + assert Objective.SUA == Objective(2) + assert str(Objective.Noise) == "Objective.Noise" + assert str(Objective.SUA) == "Objective.SUA" + +# Test train_model function +def test_train_model(monkeypatch): + output_folder = '/tmp/output' + os.makedirs(output_folder, exist_ok=True) + df.to_csv('/tmp/metrics.csv', index=False) + + def mock_load_and_preprocess_full(self, path): + self.testing_metrics = {0: df} + self.process_test_data_for_classification() + + monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) + + trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') + assert trainer is not None + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Clean up temporary files +@pytest.fixture(scope="module", autouse=True) +def cleanup(request): + def remove_tmp(): + shutil.rmtree('/tmp', ignore_errors=True) + request.addfinalizer(remove_tmp) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 2e544d086b..8ff9cc5666 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -170,8 +170,8 @@ def _get_pipeline_nodes(self): sparsity_mask = sparsity.mask spike_retriever_node = SpikeRetriever( - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices, include_spikes_in_margin=True, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index aebfd1fd78..72cbcb651f 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py 
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -95,7 +95,7 @@ def _get_pipeline_nodes(self):
         peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)

         spike_retriever_node = SpikeRetriever(
-            recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
+            sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
         )
         spike_amplitudes_node = SpikeAmplitudeNode(
             recording,
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 52a91342b6..23301292e5 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -103,8 +103,8 @@ def _get_pipeline_nodes(self):
         )

         retriever = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
         )

From 722c313382b6ac225a2c9119c676bc1bcab6e480 Mon Sep 17 00:00:00 2001
From: jakeswann1
Date: Thu, 27 Jun 2024 15:17:43 +0100
Subject: [PATCH 38/44] has_exceeding_spikes arg switch

---
 src/spikeinterface/core/basesorting.py              | 2 +-
 src/spikeinterface/core/frameslicesorting.py        | 2 +-
 src/spikeinterface/core/waveform_tools.py           | 2 +-
 src/spikeinterface/curation/remove_excess_spikes.py | 2 +-
 .../curation/tests/test_remove_excess_spikes.py     | 4 ++--
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index fd68df9dda..d9a567dedf 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
             self.get_num_segments() == recording.get_num_segments()
         ), "The recording has a different number of segments than the sorting!"
         if check_spike_frames:
-            if has_exceeding_spikes(recording, self):
+            if has_exceeding_spikes(self, recording):
                 warnings.warn(
                     "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py
index ffd8af5fd8..f3ec449ab0 100644
--- a/src/spikeinterface/core/frameslicesorting.py
+++ b/src/spikeinterface/core/frameslicesorting.py
@@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
         assert (
             start_frame <= parent_n_samples
         ), "`start_frame` should be smaller than the sorting's total number of samples."
-        if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
+        if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
            raise ValueError(
                "The sorting object has spikes whose times go beyond the recording duration."
                "This could indicate a bug in the sorter. "
" diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index befc49d034..4543074872 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None return waveforms_by_units -def has_exceeding_spikes(recording, sorting) -> bool: +def has_exceeding_spikes(sorting, recording) -> bool: """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 0ae7a59fc6..d1d6b7f3cb 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording): sorting_without_excess_spikes : Sorting The sorting without any excess spikes. """ - if has_exceeding_spikes(recording=recording, sorting=sorting): + if has_exceeding_spikes(sorting=sorting, recording=recording): return RemoveExcessSpikesSorting(sorting=sorting, recording=recording) else: return sorting diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index 69edbaba4c..141cc4c34e 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -39,10 +39,10 @@ def test_remove_excess_spikes(): labels.append(labels_segment) sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency) - assert has_exceeding_spikes(recording, sorting) + assert has_exceeding_spikes(sorting, recording) sorting_corrected = remove_excess_spikes(sorting, recording) - assert not has_exceeding_spikes(recording, sorting_corrected) + assert not has_exceeding_spikes(sorting_corrected, recording) for u in sorting.unit_ids: for segment_index in range(sorting.get_num_segments()): From d0968c4c941e290488848d14c6881c7a2cdf9c8c Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:19:24 +0100 Subject: [PATCH 39/44] removed accidental commit --- .../tests/test_train_manual_curation.py | 120 ------------------ 1 file changed, 120 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py deleted file mode 100644 index f0f9ff4d75..0000000000 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import pandas as pd -import os -import shutil - -from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model - -# Sample data for testing -data = { - 'num_spikes': [1, 2, 3, 4, 5, 6], - 'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], - 'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06], - 'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], - 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'firing_range': [0.1, 0.2, 0.3, 
0.4, 0.5, 0.6], - 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'num_positive_peaks': [1, 2, 3, 4, 5, 6], - 'num_negative_peaks': [1, 2, 3, 4, 5, 6], - 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'is_noise': [0, 1, 0, 1, 0, 1], - 'is_sua': [1, 0, 1, 0, 1, 0], - 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] -} - -df = pd.DataFrame(data) - -# Test initialization -def test_initialization(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - assert trainer.output_folder == '/tmp' - assert trainer.curator_column == 'num_spikes' - assert trainer.imputation_strategies is not None - assert trainer.scaling_techniques is not None - -# Test load_data_file -def test_load_data_file(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - df.to_csv('/tmp/test.csv', index=False) - trainer.load_data_file('/tmp/test.csv') - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Test process_test_data_for_classification -def test_process_test_data_for_classification(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - trainer.testing_metrics = {0: df} - trainer.process_test_data_for_classification() - assert trainer.noise_test is not None - assert trainer.sua_mua_test is not None - -# Test apply_scaling_imputation -def test_apply_scaling_imputation(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - y_train = df['is_noise'] - y_val = df['is_noise'] - result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) - assert result is not None - -# Test get_classifier_search_space -def test_get_classifier_search_space(): - from sklearn.linear_model import LogisticRegression - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - model, param_space = trainer.get_classifier_search_space(LogisticRegression) - assert model is not None - assert param_space is not None - -# Test Objective Enum -def test_objective_enum(): - assert Objective.Noise == Objective(1) - assert Objective.SUA == Objective(2) - assert str(Objective.Noise) == "Objective.Noise" - assert str(Objective.SUA) == "Objective.SUA" - -# Test train_model function -def test_train_model(monkeypatch): - output_folder = '/tmp/output' - os.makedirs(output_folder, exist_ok=True) - df.to_csv('/tmp/metrics.csv', index=False) - - def mock_load_and_preprocess_full(self, path): - self.testing_metrics = {0: df} - self.process_test_data_for_classification() - - 
monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) - - trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') - assert trainer is not None - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Clean up temporary files -@pytest.fixture(scope="module", autouse=True) -def cleanup(request): - def remove_tmp(): - shutil.rmtree('/tmp', ignore_errors=True) - request.addfinalizer(remove_tmp) From f687c2c2fe9b70a970cfd39d6dd7b134c15e065f Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:20:32 +0100 Subject: [PATCH 40/44] docs --- src/spikeinterface/core/waveform_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 4543074872..98380e955f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -685,10 +685,10 @@ def has_exceeding_spikes(sorting, recording) -> bool: Parameters ---------- - recording : BaseRecording - The recording object sorting : BaseSorting The sorting object + recording : BaseRecording + The recording object Returns ------- From b8c8fa83ba8695545b420d135c92f5167d7d2de1 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:54:59 +0100 Subject: [PATCH 41/44] Missed one --- .../postprocessing/tests/common_extension_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bb2f5aaafd..52dbaf23d4 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -73,7 +73,7 @@ class instance is used for each. 
In this case, we have to set self.__class__.recording, self.__class__.sorting = get_dataset() self.__class__.sparsity = estimate_sparsity( - self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20 ) self.__class__.cache_folder = create_cache_folder From 3eee955a8da3989dda6cbd84b25c0eabc2222527 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 09:01:15 -0600 Subject: [PATCH 42/44] make test skipif --- .../core/tests/test_core_tools.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 724517577c..7153991543 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -31,25 +31,25 @@ def test_add_suffix(): assert str(file_path_with_suffix) == expected_path +@pytest.mark.skipif(platform.system() == "Windows", reason="Runs on posix only") def test_path_utils_functions(): - if platform.system() != "Windows": - # posix path - d = { - "kwargs": { - "path": "/yep/sub/path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": "/yep/sub/path2"}, - }, - } + # posix path + d = { + "kwargs": { + "path": "/yep/sub/path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "/yep/sub/path2"}, + }, } + } - d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) - assert d2["kwargs"]["path"].startswith("/yop") - assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") + d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) + assert d2["kwargs"]["path"].startswith("/yop") + assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") From c24c9669dcd8e53246c376c6d33eebbf39cbab83 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 27 Jun 2024 18:32:22 +0100 Subject: [PATCH 43/44] Add *sg_execution_times.rst to gitignore. (#3097) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d981c8de4e..6c9fa6869f 100644 --- a/.gitignore +++ b/.gitignore @@ -180,6 +180,7 @@ examples/tutorials/*.svg doc/_build/* doc/tutorials/* doc/sources/* +*sg_execution_times.rst examples/getting_started/tmp_* examples/getting_started/phy From d5ec1806bf41c27317f60e7c96cf71972400774b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:58:30 -0400 Subject: [PATCH 44/44] get rid of waveform term --- src/spikeinterface/widgets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index b94167d2b7..9566989d31 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -139,7 +139,7 @@ def check_extensions(sorting_analyzer, extensions): if not sorting_analyzer.has_extension(extension): raise_error = True error_msg += ( - f"The {extension} waveform extension is required for this widget. " + f"The {extension} sorting analyzer extension is required for this widget. " f"Run the `sorting_analyzer.compute('{extension}', ...)` to compute it.\n" ) if raise_error:
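Taken together, patches 31 and 37-41 settle on a single (sorting, recording) argument order across the public helpers they touch. A hedged before/after sketch (the import paths are inferred from the diffs above, and the `method`/`radius_um` values are the ones used in the test suite):

    from spikeinterface.core import estimate_sparsity
    from spikeinterface.core.waveform_tools import has_exceeding_spikes
    from spikeinterface.curation import remove_excess_spikes

    # pre-PR order was estimate_sparsity(recording, sorting, ...)
    sparsity = estimate_sparsity(sorting, recording, method="radius", radius_um=20)

    # pre-PR order was has_exceeding_spikes(recording, sorting)
    if has_exceeding_spikes(sorting, recording):
        sorting = remove_excess_spikes(sorting, recording)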