From 0851ab0f5417847661fdef6484823a7550e898cf Mon Sep 17 00:00:00 2001
From: Manish Mohapatra
Date: Sun, 28 Jan 2024 20:36:56 -0500
Subject: [PATCH 001/103] Implementing read_spikeglx_event()

---
 .../extractors/neoextractors/spikeglx.py      | 44 ++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py
index 6a6901b62e..b1b9a1a700 100644
--- a/src/spikeinterface/extractors/neoextractors/spikeglx.py
+++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py
@@ -12,7 +12,7 @@
 from spikeinterface.core.core_tools import define_function_from_class
 from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
 
-from .neobaseextractor import NeoBaseRecordingExtractor
+from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseEventExtractor
 
 
 class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor):
@@ -100,3 +100,45 @@ def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False):
 
 
 read_spikeglx = define_function_from_class(source_class=SpikeGLXRecordingExtractor, name="read_spikeglx")
+
+class SpikeGLXEventExtractor(NeoBaseEventExtractor):
+    """
+    Class for reading events saved on the event channel by SpikeGLX software.
+
+    Parameters
+    ----------
+    folder_path: str
+
+    """
+
+    mode = "folder"
+    NeoRawIOClass = "SpikeGLXRawIO"
+    name = "spikeglx"
+
+    def __init__(self, folder_path, block_index=None):
+        neo_kwargs = self.map_to_neo_kwargs(folder_path)
+        NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs)
+
+    @classmethod
+    def map_to_neo_kwargs(cls, folder_path):
+        neo_kwargs = {"dirname": str(folder_path)}
+        return neo_kwargs
+
+def read_spikeglx_event(folder_path, block_index=None):
+    """
+    Read SpikeGLX events
+
+    Parameters
+    ----------
+    folder_path: str or Path
+        Path to openephys folder
+    block_index: int, default: None
+        If there are several blocks (experiments), specify the block index you want to load.
+
+    Returns
+    -------
+    event: SpikeGLXEventExtractor
+    """
+
+    event = SpikeGLXEventExtractor(folder_path, block_index=block_index)
+    return event
\ No newline at end of file

From 1ca24c83d19214367109a6c30c1ec773530e9fda Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 29 Jan 2024 01:37:45 +0000
Subject: [PATCH 002/103] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/extractors/neoextractors/spikeglx.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py
index b1b9a1a700..96c8b98d57 100644
--- a/src/spikeinterface/extractors/neoextractors/spikeglx.py
+++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py
@@ -101,6 +101,7 @@ def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False):
 
 read_spikeglx = define_function_from_class(source_class=SpikeGLXRecordingExtractor, name="read_spikeglx")
 
+
 class SpikeGLXEventExtractor(NeoBaseEventExtractor):
     """
     Class for reading events saved on the event channel by SpikeGLX software.
@@ -124,6 +125,7 @@ def map_to_neo_kwargs(cls, folder_path):
         neo_kwargs = {"dirname": str(folder_path)}
         return neo_kwargs
 
+
 def read_spikeglx_event(folder_path, block_index=None):
     """
     Read SpikeGLX events
@@ -141,4 +143,4 @@ def read_spikeglx_event(folder_path, block_index=None):
     """
 
     event = SpikeGLXEventExtractor(folder_path, block_index=block_index)
-    return event
\ No newline at end of file
+    return event

From ee25d9604d5b5c77f6e547774d1bffe2f3d7b2ff Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 17 Apr 2024 15:24:40 +0200
Subject: [PATCH 003/103] Add sinaps research platform recording

---
 .../extractors/extractorlist.py              |  2 +
 .../extractors/sinapsrecordingextractor.py   | 97 +++++++++++++++++++
 2 files changed, 99 insertions(+)
 create mode 100644 src/spikeinterface/extractors/sinapsrecordingextractor.py

diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py
index 228f7085bd..4957202c56 100644
--- a/src/spikeinterface/extractors/extractorlist.py
+++ b/src/spikeinterface/extractors/extractorlist.py
@@ -45,6 +45,7 @@
 from .herdingspikesextractors import HerdingspikesSortingExtractor, read_herdingspikes
 from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting
 from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort
+from .sinapsrecordingextractor import SinapsResearchPlatformRecordingExtractor, read_sinaps_research_platform
 
 # sorting in relation with simulator
 from .shybridextractors import (
@@ -77,6 +78,7 @@
     CompressedBinaryIblExtractor,
     IblRecordingExtractor,
     MCSH5RecordingExtractor,
+    SinapsResearchPlatformRecordingExtractor,
 ]
 recording_extractor_full_list += neo_recording_extractors_list
 
diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py
new file mode 100644
index 0000000000..fd42401b65
--- /dev/null
+++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py
@@ -0,0 +1,97 @@
+from pathlib import Path
+import numpy as np
+
+from ..core import BinaryRecordingExtractor, ChannelSliceRecording
+from ..core.core_tools import define_function_from_class
+
+
+class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording):
+    extractor_name = "SinapsResearchPlatform"
+    mode = "file"
+    name = "sinaps_research_platform"
+
+    def __init__(self, file_path, stream_name="filt"):
+        from ..preprocessing import UnsignedToSignedRecording
+
+        file_path = Path(file_path)
+        meta_file = file_path.parent / f"metadata_{file_path.stem}.txt"
+        meta = parse_sinaps_meta(meta_file)
+
+        num_aux_channels = meta["nbHWAux"] + meta["numberUserAUX"]
+        num_total_channels = 2 * meta["nbElectrodes"] + num_aux_channels
+        num_electrodes = meta["nbElectrodes"]
+        sampling_frequency = meta["samplingFreq"]
+
+        channel_locations = meta["electrodePhysicalPosition"]
+        num_shanks = meta["nbShanks"]
+        num_electrodes_per_shank = meta["nbElectrodesShank"]
+        num_bits = int(np.log2(meta["nbADCLevels"]))
+
+        channel_groups = []
+        for i in range(num_shanks):
+            channel_groups.extend([i] * num_electrodes_per_shank)
+
+        gain_ephys = meta["voltageConverter"]
+        gain_aux = meta["voltageAUXConverter"]
+
+        recording = BinaryRecordingExtractor(
+            file_path, sampling_frequency, dtype="uint16", num_channels=num_total_channels
+        )
+        recording = UnsignedToSignedRecording(recording, bit_depth=num_bits)
+
+        if stream_name == "raw":
+            channel_slice = recording.channel_ids[:num_electrodes]
+            renamed_channels
= np.arange(num_electrodes) + locations = channel_locations + groups = channel_groups + gain = gain_ephys + elif stream_name == "filt": + channel_slice = recording.channel_ids[num_electrodes : 2 * num_electrodes] + renamed_channels = np.arange(num_electrodes) + locations = channel_locations + groups = channel_groups + gain = gain_ephys + elif stream_name == "aux": + channel_slice = recording.channel_ids[2 * num_electrodes :] + hw_chans = meta["hwAUXChannelName"][1:-1].split(",") + user_chans = meta["userAuxName"][1:-1].split(",") + renamed_channels = hw_chans + user_chans + locations = None + groups = None + gain = gain_aux + else: + raise ValueError("stream_name must be 'raw', 'filt', or 'aux'") + + ChannelSliceRecording.__init__(self, recording, channel_ids=channel_slice, renamed_channel_ids=renamed_channels) + if locations is not None: + self.set_channel_locations(locations) + if groups is not None: + self.set_channel_groups(groups) + self.set_channel_gains(gain) + + +read_sinaps_research_platform = define_function_from_class( + source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" +) + + +def parse_sinaps_meta(meta_file): + meta_dict = {} + with open(meta_file) as f: + lines = f.readlines() + for l in lines: + if "**" in l or "=" not in l: + continue + else: + key, val = l.split("=") + val = val.replace("\n", "") + try: + val = int(val) + except: + pass + try: + val = eval(val) + except: + pass + meta_dict[key] = val + return meta_dict From ddbce702355c5a4a66969a64709ddf5fd7cc7b73 Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 22 Apr 2024 15:09:29 +0100 Subject: [PATCH 004/103] Add an H5 extractor for sinaps research platform --- .../extractors/extractorlist.py | 1 + .../extractors/sinapsrecordingh5extractor.py | 112 ++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 src/spikeinterface/extractors/sinapsrecordingh5extractor.py diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 4957202c56..b226a2d838 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -46,6 +46,7 @@ from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort from .sinapsrecordingextractor import SinapsResearchPlatformRecordingExtractor, read_sinaps_research_platform +from .sinapsrecordingh5extractor import SinapsResearchPlatformH5RecordingExtractor, read_sinaps_research_platform_h5 # sorting in relation with simulator from .shybridextractors import ( diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py new file mode 100644 index 0000000000..e1dbedebbe --- /dev/null +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -0,0 +1,112 @@ +from pathlib import Path +import numpy as np + +from ..core.core_tools import define_function_from_class +from ..core import BaseRecording, BaseRecordingSegment + +try: + import h5py + + HAVE_MCSH5 = True +except ImportError: + HAVE_MCSH5 = False + +class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): + extractor_name = "SinapsResearchPlatformH5" + mode = "file" + name = "sinaps_research_platform_h5" + + def __init__(self, file_path): + + assert self.installed, self.installation_mesg + self._file_path = file_path + + mcs_info = 
openSiNAPSFile(self._file_path) + self._rf = mcs_info["filehandle"] + + BaseRecording.__init__( + self, + sampling_frequency=mcs_info["sampling_frequency"], + channel_ids=mcs_info["channel_ids"], + dtype=mcs_info["dtype"], + ) + + self.extra_requirements.append("h5py") + + recording_segment = SiNAPSRecordingSegment( + self._rf, mcs_info["num_frames"], sampling_frequency=mcs_info["sampling_frequency"] + ) + self.add_recording_segment(recording_segment) + + # set gain + self.set_channel_gains(mcs_info["gain"]) + self.set_channel_offsets(mcs_info["offset"]) + + # set other properties + + self._kwargs = {"file_path": str(Path(file_path).absolute())} + + def __del__(self): + self._rf.close() + +class SiNAPSRecordingSegment(BaseRecordingSegment): + def __init__(self, rf, num_frames, sampling_frequency): + BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) + self._rf = rf + self._num_samples = int(num_frames) + self._stream = self._rf.require_group('RealTimeProcessedData') + + def get_num_samples(self): + return self._num_samples + + def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): + if isinstance(channel_indices, slice): + traces = self._stream.get('FilteredData')[channel_indices, start_frame:end_frame].T + else: + # channel_indices is np.ndarray + if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): + # get around h5py constraint that it does not allow datasets + # to be indexed out of order + sorted_channel_indices = np.sort(channel_indices) + resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) + recordings = self._stream.get('FilteredData')[sorted_channel_indices, start_frame:end_frame].T + traces = recordings[:, resorted_indices] + else: + traces = self._stream.get('FilteredData')[channel_indices, start_frame:end_frame].T + return traces + + +read_sinaps_research_platform_h5 = define_function_from_class( + source_class=SinapsResearchPlatformH5RecordingExtractor, name="read_sinaps_research_platform_h5" +) + +def openSiNAPSFile(filename): + """Open an SiNAPS hdf5 file, read and return the recording info.""" + rf = h5py.File(filename, "r") + + stream = rf.require_group('RealTimeProcessedData') + data = stream.get("FilteredData") + dtype = data.dtype + + parameters = rf.require_group('Parameters') + gain = parameters.get('VoltageConverter')[0] + offset = -2047 # the input data is in ADC levels, represented with 12 bits (values from 0 to 4095). 
+ # To convert the data to uV, you need to first subtract the OFFSET=2047 (half of the represented range) + # and multiply by the VoltageConverter + + nRecCh, nFrames = data.shape + + samplingRate = parameters.get('SamplingFrequency')[0] + + mcs_info = { + "filehandle": rf, + "num_frames": nFrames, + "sampling_frequency": samplingRate, + "num_channels": nRecCh, + "channel_ids": np.arange(nRecCh), + "gain": gain, + "offset": offset, + "dtype": dtype, + } + + return mcs_info From 6d0cd8599ad8251528f6acdd5daaea09c2152421 Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 22 Apr 2024 16:50:30 +0100 Subject: [PATCH 005/103] Fix OFFSET, variable naming and importing h5py --- .../extractors/sinapsrecordingh5extractor.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py index e1dbedebbe..2923011901 100644 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -4,13 +4,6 @@ from ..core.core_tools import define_function_from_class from ..core import BaseRecording, BaseRecordingSegment -try: - import h5py - - HAVE_MCSH5 = True -except ImportError: - HAVE_MCSH5 = False - class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): extractor_name = "SinapsResearchPlatformH5" mode = "file" @@ -18,29 +11,35 @@ class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): def __init__(self, file_path): + try: + import h5py + self.installed = True + except ImportError: + self.installed = False + assert self.installed, self.installation_mesg self._file_path = file_path - mcs_info = openSiNAPSFile(self._file_path) - self._rf = mcs_info["filehandle"] + sinaps_info = openSiNAPSFile(self._file_path) + self._rf = sinaps_info["filehandle"] BaseRecording.__init__( self, - sampling_frequency=mcs_info["sampling_frequency"], - channel_ids=mcs_info["channel_ids"], - dtype=mcs_info["dtype"], + sampling_frequency=sinaps_info["sampling_frequency"], + channel_ids=sinaps_info["channel_ids"], + dtype=sinaps_info["dtype"], ) self.extra_requirements.append("h5py") recording_segment = SiNAPSRecordingSegment( - self._rf, mcs_info["num_frames"], sampling_frequency=mcs_info["sampling_frequency"] + self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"] ) self.add_recording_segment(recording_segment) # set gain - self.set_channel_gains(mcs_info["gain"]) - self.set_channel_offsets(mcs_info["offset"]) + self.set_channel_gains(sinaps_info["gain"]) + self.set_channel_offsets(sinaps_info["offset"]) # set other properties @@ -82,6 +81,9 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): def openSiNAPSFile(filename): """Open an SiNAPS hdf5 file, read and return the recording info.""" + + import h5py + rf = h5py.File(filename, "r") stream = rf.require_group('RealTimeProcessedData') @@ -90,15 +92,13 @@ def openSiNAPSFile(filename): parameters = rf.require_group('Parameters') gain = parameters.get('VoltageConverter')[0] - offset = -2047 # the input data is in ADC levels, represented with 12 bits (values from 0 to 4095). 
- # To convert the data to uV, you need to first subtract the OFFSET=2047 (half of the represented range) - # and multiply by the VoltageConverter + offset = -2048 * gain nRecCh, nFrames = data.shape samplingRate = parameters.get('SamplingFrequency')[0] - mcs_info = { + sinaps_info = { "filehandle": rf, "num_frames": nFrames, "sampling_frequency": samplingRate, @@ -109,4 +109,4 @@ def openSiNAPSFile(filename): "dtype": dtype, } - return mcs_info + return sinaps_info From c3dbd28d861674e34257ff9d95886a429a49b10c Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 22 Apr 2024 16:53:40 +0100 Subject: [PATCH 006/103] Add 0 offset to support rescaling --- src/spikeinterface/extractors/sinapsrecordingextractor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py index fd42401b65..05411a8f06 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py @@ -68,6 +68,7 @@ def __init__(self, file_path, stream_name="filt"): if groups is not None: self.set_channel_groups(groups) self.set_channel_gains(gain) + self.set_channel_offsets(0) read_sinaps_research_platform = define_function_from_class( From cb5e2716cf1181f8dc24943f789a374272a70f88 Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Thu, 30 May 2024 14:10:28 +0100 Subject: [PATCH 007/103] Attach a SiNAPS probe to recording --- .../extractors/sinapsrecordingextractor.py | 42 ++++++++++++------- .../extractors/sinapsrecordingh5extractor.py | 14 +++++++ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py index 05411a8f06..be048d8276 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py @@ -1,10 +1,11 @@ from pathlib import Path import numpy as np +from probeinterface import get_probe + from ..core import BinaryRecordingExtractor, ChannelSliceRecording from ..core.core_tools import define_function_from_class - class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): extractor_name = "SinapsResearchPlatform" mode = "file" @@ -22,14 +23,15 @@ def __init__(self, file_path, stream_name="filt"): num_electrodes = meta["nbElectrodes"] sampling_frequency = meta["samplingFreq"] - channel_locations = meta["electrodePhysicalPosition"] + probe_type = meta['probeType'] + # channel_locations = meta["electrodePhysicalPosition"] # will be depricated soon by Sam, switching to probeinterface num_shanks = meta["nbShanks"] num_electrodes_per_shank = meta["nbElectrodesShank"] num_bits = int(np.log2(meta["nbADCLevels"])) - channel_groups = [] - for i in range(num_shanks): - channel_groups.extend([i] * num_electrodes_per_shank) + # channel_groups = [] + # for i in range(num_shanks): + # channel_groups.extend([i] * num_electrodes_per_shank) gain_ephys = meta["voltageConverter"] gain_aux = meta["voltageAUXConverter"] @@ -42,34 +44,44 @@ def __init__(self, file_path, stream_name="filt"): if stream_name == "raw": channel_slice = recording.channel_ids[:num_electrodes] renamed_channels = np.arange(num_electrodes) - locations = channel_locations - groups = channel_groups + # locations = channel_locations + # groups = channel_groups gain = gain_ephys elif stream_name == "filt": channel_slice = recording.channel_ids[num_electrodes : 2 * num_electrodes] 
renamed_channels = np.arange(num_electrodes) - locations = channel_locations - groups = channel_groups + # locations = channel_locations + # groups = channel_groups gain = gain_ephys elif stream_name == "aux": channel_slice = recording.channel_ids[2 * num_electrodes :] hw_chans = meta["hwAUXChannelName"][1:-1].split(",") user_chans = meta["userAuxName"][1:-1].split(",") renamed_channels = hw_chans + user_chans - locations = None - groups = None + # locations = None + # groups = None gain = gain_aux else: raise ValueError("stream_name must be 'raw', 'filt', or 'aux'") ChannelSliceRecording.__init__(self, recording, channel_ids=channel_slice, renamed_channel_ids=renamed_channels) - if locations is not None: - self.set_channel_locations(locations) - if groups is not None: - self.set_channel_groups(groups) + # if locations is not None: + # self.set_channel_locations(locations) + # if groups is not None: + # self.set_channel_groups(groups) + self.set_channel_gains(gain) self.set_channel_offsets(0) + if probe_type == 'p1024s1NHP': + probe = get_probe(manufacturer='sinaps', + probe_name='SiNAPS-p1024s1NHP') + # now wire the probe + channel_indices = np.arange(1024) + probe.set_device_channel_indices(channel_indices) + self.set_probe(probe,in_place=True) + else: + raise ValueError(f"Unknown probe type: {probe_type}") read_sinaps_research_platform = define_function_from_class( source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py index 2923011901..b20444c3c9 100644 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -1,6 +1,8 @@ from pathlib import Path import numpy as np +from probeinterface import get_probe + from ..core.core_tools import define_function_from_class from ..core import BaseRecording, BaseRecordingSegment @@ -45,6 +47,15 @@ def __init__(self, file_path): self._kwargs = {"file_path": str(Path(file_path).absolute())} + # set probe + if sinaps_info['probe_type'] == 'p1024s1NHP': + probe = get_probe(manufacturer='sinaps', + probe_name='SiNAPS-p1024s1NHP') + probe.set_device_channel_indices(np.arange(1024)) + self.set_probe(probe, in_place=True) + else: + raise ValueError(f"Unknown probe type: {sinaps_info['probe_type']}") + def __del__(self): self._rf.close() @@ -98,6 +109,8 @@ def openSiNAPSFile(filename): samplingRate = parameters.get('SamplingFrequency')[0] + probe_type = str(rf.require_group('Advanced Recording Parameters').require_group('Probe').get('probeType').asstr()[...]) + sinaps_info = { "filehandle": rf, "num_frames": nFrames, @@ -107,6 +120,7 @@ def openSiNAPSFile(filename): "gain": gain, "offset": offset, "dtype": dtype, + "probe_type": probe_type, } return sinaps_info From 1f396f4f6258b0de9977fb09bc6162723ae5a9ee Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Thu, 30 May 2024 15:30:23 +0100 Subject: [PATCH 008/103] Fix unsigned to signed --- .../extractors/sinapsrecordingh5extractor.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py index b20444c3c9..94c6e74223 100644 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -5,8 +5,10 @@ from ..core.core_tools import 
define_function_from_class from ..core import BaseRecording, BaseRecordingSegment +from ..preprocessing import UnsignedToSignedRecording -class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): + +class SinapsResearchPlatformH5RecordingExtractor_Unsigned(BaseRecording): extractor_name = "SinapsResearchPlatformH5" mode = "file" name = "sinaps_research_platform_h5" @@ -42,6 +44,7 @@ def __init__(self, file_path): # set gain self.set_channel_gains(sinaps_info["gain"]) self.set_channel_offsets(sinaps_info["offset"]) + self.num_bits = sinaps_info["num_bits"] # set other properties @@ -56,6 +59,7 @@ def __init__(self, file_path): else: raise ValueError(f"Unknown probe type: {sinaps_info['probe_type']}") + def __del__(self): self._rf.close() @@ -85,11 +89,21 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): traces = self._stream.get('FilteredData')[channel_indices, start_frame:end_frame].T return traces +class SinapsResearchPlatformH5RecordingExtractor(UnsignedToSignedRecording): + extractor_name = "SinapsResearchPlatformH5" + mode = "file" + name = "sinaps_research_platform_h5" + + def __init__(self, file_path): + recording = SinapsResearchPlatformH5RecordingExtractor_Unsigned(file_path) + UnsignedToSignedRecording.__init__(self, recording, bit_depth=recording.num_bits) + read_sinaps_research_platform_h5 = define_function_from_class( source_class=SinapsResearchPlatformH5RecordingExtractor, name="read_sinaps_research_platform_h5" ) + def openSiNAPSFile(filename): """Open an SiNAPS hdf5 file, read and return the recording info.""" @@ -103,13 +117,14 @@ def openSiNAPSFile(filename): parameters = rf.require_group('Parameters') gain = parameters.get('VoltageConverter')[0] - offset = -2048 * gain + offset = 0 nRecCh, nFrames = data.shape samplingRate = parameters.get('SamplingFrequency')[0] probe_type = str(rf.require_group('Advanced Recording Parameters').require_group('Probe').get('probeType').asstr()[...]) + num_bits = int(np.log2(rf.require_group('Advanced Recording Parameters').require_group('DAQ').get('nbADCLevels')[0])) sinaps_info = { "filehandle": rf, @@ -121,6 +136,7 @@ def openSiNAPSFile(filename): "offset": offset, "dtype": dtype, "probe_type": probe_type, + "num_bits": num_bits, } return sinaps_info From 5b82b3fec16ce3f0b8111caa6bef4455d589e85a Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Thu, 30 May 2024 15:46:18 +0100 Subject: [PATCH 009/103] Fix AUX channels which should not be attached to a probe --- .../extractors/sinapsrecordingextractor.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py index be048d8276..048c5d8e8a 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py @@ -73,15 +73,16 @@ def __init__(self, file_path, stream_name="filt"): self.set_channel_gains(gain) self.set_channel_offsets(0) - if probe_type == 'p1024s1NHP': - probe = get_probe(manufacturer='sinaps', - probe_name='SiNAPS-p1024s1NHP') - # now wire the probe - channel_indices = np.arange(1024) - probe.set_device_channel_indices(channel_indices) - self.set_probe(probe,in_place=True) - else: - raise ValueError(f"Unknown probe type: {probe_type}") + if (stream_name == 'filt') | (stream_name == 'raw'): + if (probe_type == 'p1024s1NHP'): + probe = get_probe(manufacturer='sinaps', + probe_name='SiNAPS-p1024s1NHP') + # now wire the probe + 
channel_indices = np.arange(1024) + probe.set_device_channel_indices(channel_indices) + self.set_probe(probe,in_place=True) + else: + raise ValueError(f"Unknown probe type: {probe_type}") read_sinaps_research_platform = define_function_from_class( source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" From 2526a1b8980edd6c1aed46f8b1795cc63bdaadbd Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Fri, 31 May 2024 11:15:00 +0100 Subject: [PATCH 010/103] Fix _kwargs in extractors --- .../extractors/sinapsrecordingextractor.py | 3 ++- .../extractors/sinapsrecordingh5extractor.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py index 048c5d8e8a..c54ed8ddcd 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py @@ -83,12 +83,13 @@ def __init__(self, file_path, stream_name="filt"): self.set_probe(probe,in_place=True) else: raise ValueError(f"Unknown probe type: {probe_type}") + + self._kwargs = {"file_path": str(file_path.absolute())} read_sinaps_research_platform = define_function_from_class( source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" ) - def parse_sinaps_meta(meta_file): meta_dict = {} with open(meta_file) as f: diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py index 94c6e74223..96def456dd 100644 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -46,10 +46,6 @@ def __init__(self, file_path): self.set_channel_offsets(sinaps_info["offset"]) self.num_bits = sinaps_info["num_bits"] - # set other properties - - self._kwargs = {"file_path": str(Path(file_path).absolute())} - # set probe if sinaps_info['probe_type'] == 'p1024s1NHP': probe = get_probe(manufacturer='sinaps', @@ -58,6 +54,11 @@ def __init__(self, file_path): self.set_probe(probe, in_place=True) else: raise ValueError(f"Unknown probe type: {sinaps_info['probe_type']}") + + + # set other properties + + self._kwargs = {"file_path": str(Path(file_path).absolute())} def __del__(self): @@ -98,6 +99,8 @@ def __init__(self, file_path): recording = SinapsResearchPlatformH5RecordingExtractor_Unsigned(file_path) UnsignedToSignedRecording.__init__(self, recording, bit_depth=recording.num_bits) + self._kwargs = {"file_path": str(Path(file_path).absolute())} + read_sinaps_research_platform_h5 = define_function_from_class( source_class=SinapsResearchPlatformH5RecordingExtractor, name="read_sinaps_research_platform_h5" From 0fde0d2eabd992e42a2b4ca4f8455dca89fb3766 Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 3 Jun 2024 11:48:40 +0100 Subject: [PATCH 011/103] Run black locally --- .../extractors/sinapsrecordingextractor.py | 22 ++++++------ .../extractors/sinapsrecordingh5extractor.py | 36 ++++++++++--------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py index c54ed8ddcd..1f35407c33 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractor.py @@ -6,6 +6,7 @@ from ..core import BinaryRecordingExtractor, ChannelSliceRecording from ..core.core_tools import 
define_function_from_class + class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): extractor_name = "SinapsResearchPlatform" mode = "file" @@ -23,7 +24,7 @@ def __init__(self, file_path, stream_name="filt"): num_electrodes = meta["nbElectrodes"] sampling_frequency = meta["samplingFreq"] - probe_type = meta['probeType'] + probe_type = meta["probeType"] # channel_locations = meta["electrodePhysicalPosition"] # will be depricated soon by Sam, switching to probeinterface num_shanks = meta["nbShanks"] num_electrodes_per_shank = meta["nbElectrodesShank"] @@ -66,30 +67,31 @@ def __init__(self, file_path, stream_name="filt"): ChannelSliceRecording.__init__(self, recording, channel_ids=channel_slice, renamed_channel_ids=renamed_channels) # if locations is not None: - # self.set_channel_locations(locations) + # self.set_channel_locations(locations) # if groups is not None: - # self.set_channel_groups(groups) - + # self.set_channel_groups(groups) + self.set_channel_gains(gain) self.set_channel_offsets(0) - if (stream_name == 'filt') | (stream_name == 'raw'): - if (probe_type == 'p1024s1NHP'): - probe = get_probe(manufacturer='sinaps', - probe_name='SiNAPS-p1024s1NHP') + if (stream_name == "filt") | (stream_name == "raw"): + if probe_type == "p1024s1NHP": + probe = get_probe(manufacturer="sinaps", probe_name="SiNAPS-p1024s1NHP") # now wire the probe channel_indices = np.arange(1024) probe.set_device_channel_indices(channel_indices) - self.set_probe(probe,in_place=True) + self.set_probe(probe, in_place=True) else: raise ValueError(f"Unknown probe type: {probe_type}") - + self._kwargs = {"file_path": str(file_path.absolute())} + read_sinaps_research_platform = define_function_from_class( source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" ) + def parse_sinaps_meta(meta_file): meta_dict = {} with open(meta_file) as f: diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py index 96def456dd..dbfcb239fa 100644 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py @@ -17,6 +17,7 @@ def __init__(self, file_path): try: import h5py + self.installed = True except ImportError: self.installed = False @@ -47,36 +48,34 @@ def __init__(self, file_path): self.num_bits = sinaps_info["num_bits"] # set probe - if sinaps_info['probe_type'] == 'p1024s1NHP': - probe = get_probe(manufacturer='sinaps', - probe_name='SiNAPS-p1024s1NHP') + if sinaps_info["probe_type"] == "p1024s1NHP": + probe = get_probe(manufacturer="sinaps", probe_name="SiNAPS-p1024s1NHP") probe.set_device_channel_indices(np.arange(1024)) self.set_probe(probe, in_place=True) else: raise ValueError(f"Unknown probe type: {sinaps_info['probe_type']}") - # set other properties self._kwargs = {"file_path": str(Path(file_path).absolute())} - def __del__(self): self._rf.close() + class SiNAPSRecordingSegment(BaseRecordingSegment): def __init__(self, rf, num_frames, sampling_frequency): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) self._rf = rf self._num_samples = int(num_frames) - self._stream = self._rf.require_group('RealTimeProcessedData') + self._stream = self._rf.require_group("RealTimeProcessedData") def get_num_samples(self): return self._num_samples def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): if isinstance(channel_indices, slice): - traces = 
self._stream.get('FilteredData')[channel_indices, start_frame:end_frame].T + traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T else: # channel_indices is np.ndarray if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): @@ -84,12 +83,13 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): # to be indexed out of order sorted_channel_indices = np.sort(channel_indices) resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) - recordings = self._stream.get('FilteredData')[sorted_channel_indices, start_frame:end_frame].T + recordings = self._stream.get("FilteredData")[sorted_channel_indices, start_frame:end_frame].T traces = recordings[:, resorted_indices] else: - traces = self._stream.get('FilteredData')[channel_indices, start_frame:end_frame].T + traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T return traces + class SinapsResearchPlatformH5RecordingExtractor(UnsignedToSignedRecording): extractor_name = "SinapsResearchPlatformH5" mode = "file" @@ -109,25 +109,29 @@ def __init__(self, file_path): def openSiNAPSFile(filename): """Open an SiNAPS hdf5 file, read and return the recording info.""" - + import h5py rf = h5py.File(filename, "r") - stream = rf.require_group('RealTimeProcessedData') + stream = rf.require_group("RealTimeProcessedData") data = stream.get("FilteredData") dtype = data.dtype - parameters = rf.require_group('Parameters') - gain = parameters.get('VoltageConverter')[0] + parameters = rf.require_group("Parameters") + gain = parameters.get("VoltageConverter")[0] offset = 0 nRecCh, nFrames = data.shape - samplingRate = parameters.get('SamplingFrequency')[0] + samplingRate = parameters.get("SamplingFrequency")[0] - probe_type = str(rf.require_group('Advanced Recording Parameters').require_group('Probe').get('probeType').asstr()[...]) - num_bits = int(np.log2(rf.require_group('Advanced Recording Parameters').require_group('DAQ').get('nbADCLevels')[0])) + probe_type = str( + rf.require_group("Advanced Recording Parameters").require_group("Probe").get("probeType").asstr()[...] 
+ ) + num_bits = int( + np.log2(rf.require_group("Advanced Recording Parameters").require_group("DAQ").get("nbADCLevels")[0]) + ) sinaps_info = { "filehandle": rf, From ad8a062c6769b060fa46d68e892bc728b8290189 Mon Sep 17 00:00:00 2001 From: Julien Verplanken Date: Fri, 7 Jun 2024 10:41:13 +0200 Subject: [PATCH 012/103] add whiteningRange as kilosort2_5 parameter --- src/spikeinterface/sorters/external/kilosort2_5.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index abde2ab324..beccba3481 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -53,6 +53,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "nPCs": 3, "ntbuff": 64, "nfilt_factor": 4, + "whiteningRange": 32.0, "NT": None, "AUCsplit": 0.9, "do_correction": True, @@ -82,6 +83,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "ntbuff": "Samples of symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", + "whiteningRange": "Number of channels to use for whitening each channel", "NT": "Batch size (if None it is automatically computed)", "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "keep_good_only": "If True only 'good' units are returned", @@ -220,7 +222,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). - ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into ops["useRAM"] = 0.0 # not yet available From a54612bfc0d4f8e704800854a3a119a05129be1c Mon Sep 17 00:00:00 2001 From: Julien Verplanken Date: Fri, 7 Jun 2024 10:54:09 +0200 Subject: [PATCH 013/103] reorder to match positions in param dictionaries --- src/spikeinterface/sorters/external/kilosort2_5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index beccba3481..b3d1718d59 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -41,6 +41,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [10, 4], "preclust_threshold": 8, + "whiteningRange": 32.0, "momentum": [20.0, 400.0], "car": True, "minFR": 0.1, @@ -53,7 +54,6 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "nPCs": 3, "ntbuff": 64, "nfilt_factor": 4, - "whiteningRange": 32.0, "NT": None, "AUCsplit": 0.9, "do_correction": True, @@ -70,6 +70,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + "whiteningRange": "Number of channels to use for whitening each channel", "momentum": "Number of samples to average over (annealed from first to second value)", "car": "Enable or disable common 
reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", @@ -83,7 +84,6 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "ntbuff": "Samples of symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", - "whiteningRange": "Number of channels to use for whitening each channel", "NT": "Batch size (if None it is automatically computed)", "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "keep_good_only": "If True only 'good' units are returned", From b2b9001b343a285c07640693ea41fc6facdebbfd Mon Sep 17 00:00:00 2001 From: Julien Verplanken Date: Fri, 7 Jun 2024 15:35:37 +0200 Subject: [PATCH 014/103] added whiteningRange parameter to KS2 and KS3 --- src/spikeinterface/sorters/external/kilosort2.py | 4 +++- src/spikeinterface/sorters/external/kilosort3.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index bdc0372789..0425ad5e53 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -37,6 +37,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [10, 4], "preclust_threshold": 8, + "whiteningRange": 32, # samples of the template to use for whitening "spatial" dimension "momentum": [20.0, 400.0], "car": True, "minFR": 0.1, @@ -62,6 +63,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + "whiteningRange": "Number of channels to use for whitening each channel", "momentum": "Number of samples to average over (annealed from first to second value)", "car": "Enable or disable common reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", @@ -199,7 +201,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). 
- ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into ops["useRAM"] = 0.0 # not yet available diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index 3d2103ea66..f560fd7e1e 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -38,6 +38,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "detect_threshold": 6, "projection_threshold": [9, 9], "preclust_threshold": 8, + "whiteningRange": 32, "car": True, "minFR": 0.2, "minfr_goodchannels": 0.2, @@ -65,6 +66,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "detect_threshold": "Threshold for spike detection", "projection_threshold": "Threshold on projections", "preclust_threshold": "Threshold crossings for pre-clustering (in PCA projection space)", + "whiteningRange": "number of channels to use for whitening each channel", "car": "Enable or disable common reference", "minFR": "Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed", "minfr_goodchannels": "Minimum firing rate on a 'good' channel", @@ -212,7 +214,7 @@ def _get_specific_options(cls, ops, params): ops["NT"] = params[ "NT" ] # must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). - ops["whiteningRange"] = 32.0 # number of channels to use for whitening each channel + ops["whiteningRange"] = params["whiteningRange"] # number of channels to use for whitening each channel ops["nSkipCov"] = 25.0 # compute whitening matrix from every N-th batch ops["scaleproc"] = 200.0 # int16 scaling of whitened data ops["nPCs"] = params["nPCs"] # how many PCs to project the spikes into From c1c0cb6b3023f9e274ffb9012966ce7db841d649 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:56:01 +0100 Subject: [PATCH 015/103] remove extremum from spikelocations init --- src/spikeinterface/postprocessing/spike_locations.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d468bd90ab..96e01a68c4 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -59,9 +59,6 @@ class ComputeSpikeLocations(AnalyzerExtension): def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") - self.spikes = self.sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - def _set_params( self, ms_before=0.5, @@ -89,8 +86,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.sorting_analyzer.unit_ids unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) + spikes = self.sorting_analyzer.sorting.to_spike_vector() - spike_mask = np.isin(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(spikes["unit_index"], unit_inds) new_spike_locations = self.data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) From cf5041062a73b1c61d8f15200ade73ea1f1d8bae Mon Sep 17 00:00:00 
2001 From: JoeZiminski Date: Wed, 12 Jun 2024 19:14:51 +0100 Subject: [PATCH 016/103] Run checks for singularity, docker and related python module installations. --- src/spikeinterface/sorters/runsorter.py | 18 ++++++++++++++ src/spikeinterface/sorters/utils/misc.py | 31 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index baec6aaac3..44a08a34a7 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -169,6 +169,15 @@ def run_sorter( container_image = None else: container_image = docker_image + + if not has_docker(): + raise RuntimeError("Docker is not installed. Install docker " + "on this machine to run sorting with docker.") + + if not has_docker_python(): + raise RuntimeError("The python `docker` package must be installed." + "Install with `pip install docker`") + else: mode = "singularity" assert not docker_image @@ -176,6 +185,15 @@ def run_sorter( container_image = None else: container_image = singularity_image + + if not has_singularity(): + raise RuntimeError("Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity.") + + if not has_spython(): + raise RuntimeError("The python singularity package must be installed." + "Install with `pip install spython`") + return run_sorter_container( container_image=container_image, mode=mode, diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 0a6b4a986c..a1cf34f059 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import subprocess # TODO: decide best format for this from subprocess import check_output, CalledProcessError from typing import List, Union @@ -80,3 +81,33 @@ def has_nvidia(): return device_count > 0 except RuntimeError: # Failed to dlopen libcuda.so return False + +def _run_subprocess_silently(command): + output = subprocess.run( + command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return output + + +def has_docker(): + return self._run_subprocess_silently("docker --version").returncode == 0 + + +def has_singularity(): + return self._run_subprocess_silently("singularity --version").returncode == 0 + + +def has_docker_python(): + try: + import docker + return True + except ImportError: + return False + + +def has_spython(): + try: + import spython + return True + except ImportError: + return False From e49521939f2023c50943afad21a663c3d7822011 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 20:09:03 +0100 Subject: [PATCH 017/103] Add nvidia dependency checks, tidy up. 
--- src/spikeinterface/sorters/runsorter.py | 17 +++++++++++---- src/spikeinterface/sorters/utils/__init__.py | 2 +- src/spikeinterface/sorters/utils/misc.py | 22 +++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 44a08a34a7..884cba590f 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia +from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies from .container_tools import ( find_recording_folders, path_to_unix, @@ -175,7 +175,7 @@ def run_sorter( "on this machine to run sorting with docker.") if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed." + raise RuntimeError("The python `docker` package must be installed. " "Install with `pip install docker`") else: @@ -191,8 +191,8 @@ def run_sorter( "on this machine to run sorting with singularity.") if not has_spython(): - raise RuntimeError("The python singularity package must be installed." - "Install with `pip install spython`") + raise RuntimeError("The python `spython` package must be installed to " + "run singularity. Install with `pip install spython`") return run_sorter_container( container_image=container_image, @@ -480,6 +480,15 @@ def run_sorter_container( if gpu_capability == "nvidia-required": assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available" extra_kwargs["container_requires_gpu"] = True + + if platform.system() == "Linux" and has_docker_nvidia_installed(): + warn( + f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. " + f"This may result in an error being raised during sorting. Try " + "installing `nvidia-container-toolkit`, including setting the " + "configuration steps, if running into errors." 
+ ) + elif gpu_capability == "nvidia-optional": if has_nvidia(): extra_kwargs["container_requires_gpu"] = True diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py index 6cad10b211..7f6f3089d4 100644 --- a/src/spikeinterface/sorters/utils/__init__.py +++ b/src/spikeinterface/sorters/utils/__init__.py @@ -1,2 +1,2 @@ from .shellscript import ShellScript -from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path +from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index a1cf34f059..4a900f4485 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -82,6 +82,7 @@ def has_nvidia(): except RuntimeError: # Failed to dlopen libcuda.so return False + def _run_subprocess_silently(command): output = subprocess.run( command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL @@ -90,12 +91,27 @@ def _run_subprocess_silently(command): def has_docker(): - return self._run_subprocess_silently("docker --version").returncode == 0 + return _run_subprocess_silently("docker --version").returncode == 0 def has_singularity(): - return self._run_subprocess_silently("singularity --version").returncode == 0 - + return _run_subprocess_silently("singularity --version").returncode == 0 + +def get_nvidia_docker_dependecies(): + return [ + "nvidia-docker", + "nvidia-docker2", + "nvidia-container-toolkit", + ] + +def has_docker_nvidia_installed(): + all_dependencies = get_nvidia_docker_dependecies() + has_dep = [] + for dep in all_dependencies: + has_dep.append( + _run_subprocess_silently(f"{dep} --version").returncode == 0 + ) + return not any(has_dep) def has_docker_python(): try: From e0656bb86901127c8b1c0f708e4970584e79a40d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 20:15:48 +0100 Subject: [PATCH 018/103] Add docstrings. --- src/spikeinterface/sorters/utils/misc.py | 44 ++++++++++++++++++------ 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 4a900f4485..66744fbab1 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -84,9 +84,10 @@ def has_nvidia(): def _run_subprocess_silently(command): - output = subprocess.run( - command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) + """ + Run a subprocess command without outputting to stderr or stdout. + """ + output = subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) return output @@ -97,25 +98,45 @@ def has_docker(): def has_singularity(): return _run_subprocess_silently("singularity --version").returncode == 0 + +def has_docker_nvidia_installed(): + """ + On Linux, nvidia has a set of container dependencies + that are required for running GPU in docker. This is a little + complex and is described in more detail in the links below. + To summarise breifly, at least one of the `get_nvidia_docker_dependecies()` + is almost certainly required to run docker with GPU. 
+ + https://github.com/NVIDIA/nvidia-docker/issues/1268 + https://www.howtogeek.com/devops/how-to-use-an-nvidia-gpu-with-docker-containers/ + + Returns + ------- + Whether at least one of the dependencies listed in + `get_nvidia_docker_dependecies()` is installed. + """ + all_dependencies = get_nvidia_docker_dependecies() + has_dep = [] + for dep in all_dependencies: + has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0) + return not any(has_dep) + + def get_nvidia_docker_dependecies(): + """ + See `has_docker_nvidia_installed()` + """ return [ "nvidia-docker", "nvidia-docker2", "nvidia-container-toolkit", ] -def has_docker_nvidia_installed(): - all_dependencies = get_nvidia_docker_dependecies() - has_dep = [] - for dep in all_dependencies: - has_dep.append( - _run_subprocess_silently(f"{dep} --version").returncode == 0 - ) - return not any(has_dep) def has_docker_python(): try: import docker + return True except ImportError: return False @@ -124,6 +145,7 @@ def has_docker_python(): def has_spython(): try: import spython + return True except ImportError: return False From b145b04ac31a8de3d9c9fbfc56b4a9974ce0eb3a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:21:51 +0100 Subject: [PATCH 019/103] Add tests for runsorter dependencies. --- src/spikeinterface/sorters/runsorter.py | 35 +++-- .../tests/test_runsorter_dependency_checks.py | 144 ++++++++++++++++++ 2 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 884cba590f..5b2e80b83d 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,7 +19,18 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict -from .utils import SpikeSortingError, has_nvidia, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies + +# full import required for monkeypatch testing. +from spikeinterface.sorters.utils import ( + SpikeSortingError, + has_nvidia, + has_docker, + has_docker_python, + has_singularity, + has_spython, + has_docker_nvidia_installed, + get_nvidia_docker_dependecies, +) from .container_tools import ( find_recording_folders, path_to_unix, @@ -171,12 +182,14 @@ def run_sorter( container_image = docker_image if not has_docker(): - raise RuntimeError("Docker is not installed. Install docker " - "on this machine to run sorting with docker.") + raise RuntimeError( + "Docker is not installed. Install docker " "on this machine to run sorting with docker." + ) if not has_docker_python(): - raise RuntimeError("The python `docker` package must be installed. " - "Install with `pip install docker`") + raise RuntimeError( + "The python `docker` package must be installed. " "Install with `pip install docker`" + ) else: mode = "singularity" @@ -187,12 +200,16 @@ def run_sorter( container_image = singularity_image if not has_singularity(): - raise RuntimeError("Singularity is not installed. Install singularity " - "on this machine to run sorting with singularity.") + raise RuntimeError( + "Singularity is not installed. Install singularity " + "on this machine to run sorting with singularity." + ) if not has_spython(): - raise RuntimeError("The python `spython` package must be installed to " - "run singularity. 
Install with `pip install spython`") + raise RuntimeError( + "The python `spython` package must be installed to " + "run singularity. Install with `pip install spython`" + ) return run_sorter_container( container_image=container_image, diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py new file mode 100644 index 0000000000..8dbb1b20f6 --- /dev/null +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -0,0 +1,144 @@ +import os +import pytest +from pathlib import Path +import shutil +import platform +from spikeinterface import generate_ground_truth_recording +from spikeinterface.sorters.utils import has_spython, has_docker_python +from spikeinterface.sorters import run_sorter +import subprocess +import sys +import copy + + +def _monkeypatch_return_false(): + return False + + +class TestRunersorterDependencyChecks: + """ + This class performs tests to check whether expected + dependency checks prior to sorting are run. The + run_sorter function should raise an error if: + - singularity is not installed + - spython is not installed (python package) + - docker is not installed + - docker is not installed (python package) + when running singularity / docker respectively. + + Two separate checks should be run. First, that the + relevant `has_` function (indicating if the dependency + is installed) is working. Unfortunately it is not possible to + easily test this core singularity and docker installs, so this is not done. + `uninstall_python_dependency()` allows a test to check if the + `has_spython()` and `has_docker_dependency()` return `False` as expected + when these python modules are not installed. + + Second, the `run_sorters()` function should return the appropriate error + when these functions return that the dependency is not available. This is + easier to test as these `has_` reporting functions can be + monkeypatched to return False at runtime. This is done for these 4 + dependency checks, and tests check the expected error is raised. + + Notes + ---- + `has_nvidia()` and `has_docker_nvidia_installed()` are not tested + as these are complex GPU-related dependencies which are difficult to mock. + """ + + @pytest.fixture(scope="function") + def uninstall_python_dependency(self, request): + """ + This python fixture mocks python modules not been importable + by setting the relevant `sys.modules` dict entry to `None`. + It uses `yeild` so that the function can tear-down the test + (even if it failed) and replace the patched `sys.module` entry. + + This function uses an `indirect` parameterisation, meaning the + `request.param` is passed to the fixture at the start of the + test function. This is used to reuse code for nearly identical + `spython` and `docker` python dependency tests. + """ + dep_name = request.param + assert dep_name in ["spython", "docker"] + + try: + if dep_name == "spython": + import spython + else: + import docker + dependency_installed = True + except: + dependency_installed = False + + if dependency_installed: + copy_import = sys.modules[dep_name] + sys.modules[dep_name] = None + yield + if dependency_installed: + sys.modules[dep_name] = copy_import + + @pytest.fixture(scope="session") + def recording(self): + """ + Make a small recording to have something to pass to the sorter. 
+ """ + recording, _ = generate_ground_truth_recording(durations=[10]) + return recording + + @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") + @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) + def test_has_spython(self, recording, uninstall_python_dependency): + """ + Test the `has_spython()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_spython() is False + + @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) + def test_has_docker_python(self, recording, uninstall_python_dependency): + """ + Test the `has_docker_python()` function, see class docstring and + `uninstall_python_dependency()` for details. + """ + assert has_docker_python() is False + + @pytest.mark.parametrize("dependency", ["singularity", "spython"]) + def test_has_singularity_and_spython(self, recording, monkeypatch, dependency): + """ + When running a sorting, if singularity dependencies (singularity + itself or the `spython` package`) are not installed, an error is raised. + Beacause it is hard to actually uninstall these dependencies, the + `has_` functions that let `run_sorter` know if the dependency + are installed are monkeypatched. This is done so at runtime these always + return False. Then, test the expected error is raised when the dependency + is not found. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + if dependency == "spython": + assert "The python `spython` package must be installed" in str(e) + else: + assert "Singularity is not installed." in str(e) + + @pytest.mark.parametrize("dependency", ["docker", "docker_python"]) + def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency): + """ + See `test_has_singularity_and_spython()` for details. This test + is almost identical, but with some key changes for Docker. + """ + test_func = f"has_{dependency}" + + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + if dependency == "docker_python": + assert "The python `docker` package must be installed" in str(e) + else: + assert "Docker is not installed." in str(e) From 78ccc2719676b238dbd92d2ad5384786ca0724e0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:24:29 +0100 Subject: [PATCH 020/103] Remove unnecessary non-relative import. --- src/spikeinterface/sorters/runsorter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 5b2e80b83d..c16435cdb5 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -19,9 +19,7 @@ from ..core import BaseRecording, NumpySorting, load_extractor from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict - -# full import required for monkeypatch testing. 
-from spikeinterface.sorters.utils import ( +from .utils import ( SpikeSortingError, has_nvidia, has_docker, From f1438c4ce20bbd7ae3c910b793f92ebb4d723253 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:27:03 +0100 Subject: [PATCH 021/103] Fix some string formatting, add docstring to monkeypatch function. --- src/spikeinterface/sorters/runsorter.py | 6 ++---- .../sorters/tests/test_runsorter_dependency_checks.py | 4 ++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index c16435cdb5..f9994dd38d 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -181,13 +181,11 @@ def run_sorter( if not has_docker(): raise RuntimeError( - "Docker is not installed. Install docker " "on this machine to run sorting with docker." + "Docker is not installed. Install docker on this machine to run sorting with docker." ) if not has_docker_python(): - raise RuntimeError( - "The python `docker` package must be installed. " "Install with `pip install docker`" - ) + raise RuntimeError("The python `docker` package must be installed. Install with `pip install docker`") else: mode = "singularity" diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index 8dbb1b20f6..c81593b7db 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -12,6 +12,10 @@ def _monkeypatch_return_false(): + """ + A function to monkeypatch the `has_` functions, + ensuring the always return `False` at runtime. + """ return False From fd4406e0826f80329614e3b59388e9640c00fe3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 20:27:36 +0000 Subject: [PATCH 022/103] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/utils/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/__init__.py b/src/spikeinterface/sorters/utils/__init__.py index 7f6f3089d4..62317be6f2 100644 --- a/src/spikeinterface/sorters/utils/__init__.py +++ b/src/spikeinterface/sorters/utils/__init__.py @@ -1,2 +1,14 @@ from .shellscript import ShellScript -from .misc import SpikeSortingError, get_git_commit, has_nvidia, get_matlab_shell_name, get_bash_path, has_docker, has_docker_python, has_singularity, has_spython, has_docker_nvidia_installed, get_nvidia_docker_dependecies +from .misc import ( + SpikeSortingError, + get_git_commit, + has_nvidia, + get_matlab_shell_name, + get_bash_path, + has_docker, + has_docker_python, + has_singularity, + has_spython, + has_docker_nvidia_installed, + get_nvidia_docker_dependecies, +) From 7af611ba289e220c4bf36f4b62ae26efe94f93b1 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:42:26 +0100 Subject: [PATCH 023/103] Mock all has functions to ensure tests do not depend on actual dependencies. 
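
The pattern these tests rely on, as a minimal sketch (it assumes pytest's
`monkeypatch` fixture and some small `recording` object; the names below are
placeholders rather than the exact test code):

    import pytest
    from spikeinterface.sorters import run_sorter

    def test_docker_missing_error(monkeypatch, recording):
        # Patch the name looked up inside runsorter so the dependency check
        # reports "not installed", whatever the local environment provides.
        monkeypatch.setattr("spikeinterface.sorters.runsorter.has_docker", lambda: False)
        with pytest.raises(RuntimeError, match="Docker is not installed"):
            run_sorter("kilosort2_5", recording, docker_image=True)

The target string points at `spikeinterface.sorters.runsorter` (where the
function is called) rather than at `spikeinterface.sorters.utils.misc` (where
it is defined); patching the definition site would not affect the name already
imported into the runsorter module.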
--- .../tests/test_runsorter_dependency_checks.py | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index c81593b7db..a248033089 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -4,7 +4,7 @@ import shutil import platform from spikeinterface import generate_ground_truth_recording -from spikeinterface.sorters.utils import has_spython, has_docker_python +from spikeinterface.sorters.utils import has_spython, has_docker_python, has_docker, has_singularity from spikeinterface.sorters import run_sorter import subprocess import sys @@ -19,6 +19,10 @@ def _monkeypatch_return_false(): return False +def _monkeypatch_return_true(): + return True + + class TestRunersorterDependencyChecks: """ This class performs tests to check whether expected @@ -91,6 +95,7 @@ def recording(self): return recording @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") + @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.") @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) def test_has_spython(self, recording, uninstall_python_dependency): """ @@ -100,6 +105,7 @@ def test_has_spython(self, recording, uninstall_python_dependency): assert has_spython() is False @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) + @pytest.mark.skipif(not has_docker(), reason="docker required for this test.") def test_has_docker_python(self, recording, uninstall_python_dependency): """ Test the `has_docker_python()` function, see class docstring and @@ -107,8 +113,7 @@ def test_has_docker_python(self, recording, uninstall_python_dependency): """ assert has_docker_python() is False - @pytest.mark.parametrize("dependency", ["singularity", "spython"]) - def test_has_singularity_and_spython(self, recording, monkeypatch, dependency): + def test_no_singularity_error_raised(self, recording, monkeypatch): """ When running a sorting, if singularity dependencies (singularity itself or the `spython` package`) are not installed, an error is raised. @@ -118,31 +123,46 @@ def test_has_singularity_and_spython(self, recording, monkeypatch, dependency): return False. Then, test the expected error is raised when the dependency is not found. """ - test_func = f"has_{dependency}" + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_false) - monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) with pytest.raises(RuntimeError) as e: run_sorter("kilosort2_5", recording, singularity_image=True) - if dependency == "spython": - assert "The python `spython` package must be installed" in str(e) - else: - assert "Singularity is not installed." in str(e) + assert "Singularity is not installed." in str(e) - @pytest.mark.parametrize("dependency", ["docker", "docker_python"]) - def test_has_docker_and_docker_python(self, recording, monkeypatch, dependency): + def test_no_spython_error_raised(self, recording, monkeypatch): """ - See `test_has_singularity_and_spython()` for details. This test - is almost identical, but with some key changes for Docker. + See `test_no_singularity_error_raised()`. 
""" - test_func = f"has_{dependency}" + # make sure singularity test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_singularity", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_spython", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, singularity_image=True) + + assert "The python `spython` package must be installed" in str(e) - monkeypatch.setattr(f"spikeinterface.sorters.runsorter.{test_func}", _monkeypatch_return_false) + def test_no_docker_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_false) + + with pytest.raises(RuntimeError) as e: + run_sorter("kilosort2_5", recording, docker_image=True) + + assert "Docker is not installed." in str(e) + + def test_as_no_docker_python_error_raised(self, recording, monkeypatch): + """ + See `test_no_singularity_error_raised()`. + """ + # make sure docker test returns true as that comes first + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker", _monkeypatch_return_true) + monkeypatch.setattr(f"spikeinterface.sorters.runsorter.has_docker_python", _monkeypatch_return_false) with pytest.raises(RuntimeError) as e: run_sorter("kilosort2_5", recording, docker_image=True) - if dependency == "docker_python": - assert "The python `docker` package must be installed" in str(e) - else: - assert "Docker is not installed." in str(e) + assert "The python `docker` package must be installed" in str(e) From 0c0b1f908d8e356b9a58cacd4524ace871ff93b3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 21:43:10 +0100 Subject: [PATCH 024/103] Remove unecessary skips. --- .../sorters/tests/test_runsorter_dependency_checks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index a248033089..741fe4ae0e 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -95,7 +95,6 @@ def recording(self): return recording @pytest.mark.skipif(platform.system() != "Linux", reason="spython install only for Linux.") - @pytest.mark.skipif(not has_singularity(), reason="singularity required for this test.") @pytest.mark.parametrize("uninstall_python_dependency", ["spython"], indirect=True) def test_has_spython(self, recording, uninstall_python_dependency): """ @@ -105,7 +104,6 @@ def test_has_spython(self, recording, uninstall_python_dependency): assert has_spython() is False @pytest.mark.parametrize("uninstall_python_dependency", ["docker"], indirect=True) - @pytest.mark.skipif(not has_docker(), reason="docker required for this test.") def test_has_docker_python(self, recording, uninstall_python_dependency): """ Test the `has_docker_python()` function, see class docstring and From 1be1dbd39a339ff56c0803ff7a59e5650d95b781 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 13 Jun 2024 09:04:04 +0100 Subject: [PATCH 025/103] Update docstrings. 
--- .../sorters/tests/test_runsorter_dependency_checks.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index 741fe4ae0e..c4beaba072 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -20,14 +20,18 @@ def _monkeypatch_return_false(): def _monkeypatch_return_true(): + """ + Monkeypatch for some `has_` functions to + return `True` so functions that are later in the + `runsorter` code can be checked. + """ return True class TestRunersorterDependencyChecks: """ - This class performs tests to check whether expected - dependency checks prior to sorting are run. The - run_sorter function should raise an error if: + This class tests whether expected dependency checks prior to sorting are run. + The run_sorter function should raise an error if: - singularity is not installed - spython is not installed (python package) - docker is not installed From 00663080b03f7933d37ba4ff2ee32e3402aa200e Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 13 Jun 2024 09:10:30 +0100 Subject: [PATCH 026/103] Swap return bool for to match function name. --- src/spikeinterface/sorters/runsorter.py | 2 +- src/spikeinterface/sorters/utils/misc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index f9994dd38d..80608f8973 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -494,7 +494,7 @@ def run_sorter_container( assert has_nvidia(), "The container requires a NVIDIA GPU capability, but it is not available" extra_kwargs["container_requires_gpu"] = True - if platform.system() == "Linux" and has_docker_nvidia_installed(): + if platform.system() == "Linux" and not has_docker_nvidia_installed(): warn( f"nvidia-required but none of \n{get_nvidia_docker_dependecies()}\n were found. " f"This may result in an error being raised during sorting. 
Try " diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 66744fbab1..1e01b9c052 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -119,7 +119,7 @@ def has_docker_nvidia_installed(): has_dep = [] for dep in all_dependencies: has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0) - return not any(has_dep) + return any(has_dep) def get_nvidia_docker_dependecies(): From 9664f69c4bcdd24e20584f601bcbd6a9ae79e174 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jun 2024 11:49:32 +0200 Subject: [PATCH 027/103] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../sorters/tests/test_runsorter_dependency_checks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py index c4beaba072..83d6ec3161 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py +++ b/src/spikeinterface/sorters/tests/test_runsorter_dependency_checks.py @@ -13,7 +13,7 @@ def _monkeypatch_return_false(): """ - A function to monkeypatch the `has_` functions, + A function to monkeypatch the `has_` functions, ensuring the always return `False` at runtime. """ return False @@ -61,12 +61,12 @@ class TestRunersorterDependencyChecks: @pytest.fixture(scope="function") def uninstall_python_dependency(self, request): """ - This python fixture mocks python modules not been importable + This python fixture mocks python modules not being importable by setting the relevant `sys.modules` dict entry to `None`. - It uses `yeild` so that the function can tear-down the test + It uses `yield` so that the function can tear-down the test (even if it failed) and replace the patched `sys.module` entry. - This function uses an `indirect` parameterisation, meaning the + This function uses an `indirect` parameterization, meaning the `request.param` is passed to the fixture at the start of the test function. This is used to reuse code for nearly identical `spython` and `docker` python dependency tests. 
From c504fc63f94bf0d31b5aea7329cc067f2d81dee4 Mon Sep 17 00:00:00 2001
From: zm711 <92116279+zm711@users.noreply.github.com>
Date: Wed, 19 Jun 2024 12:40:57 -0400
Subject: [PATCH 028/103] fix for load_json

---
 src/spikeinterface/core/sortinganalyzer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 46d02099d5..0094012013 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1608,7 +1608,9 @@ def load_data(self):
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             for ext_data_file in extension_folder.iterdir():
-                if ext_data_file.name == "params.json":
+                # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041
+                # maybe add a check for version number from the info.json during loading only
+                if ext_data_file.name == "params.json" or ext_data_file.name == "info.json":
                     continue
                 ext_data_name = ext_data_file.stem
                 if ext_data_file.suffix == ".json":

From 375620fa1589b8fdb46e7e9289909992ac5b0398 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 19 Jun 2024 19:01:08 +0200
Subject: [PATCH 029/103] fix spike_vector_to_indices

---
 src/spikeinterface/core/sorting_tools.py      | 15 ++++++++++++++-
 .../postprocessing/spike_amplitudes.py        |  2 +-
 .../postprocessing/spike_locations.py         |  2 +-
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 2313e7d253..5e3af58198 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -47,7 +47,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra
     return spike_trains
 
 
-def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array):
+def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolut_index=False):
     """
     Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return
     spike indices by segment and units.
@@ -61,6 +61,12 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array):
         List of spike vectors optained with sorting.to_spike_vector(concatenated=False)
     unit_ids: np.array
         Unit ids
+    absolut_index: bool, default False
+        If True, return absolute spike indices (indices into the single concatenated spike vector),
+        which is what is needed when a unique spike vector (or amplitudes array) is used.
+        If False, return indices relative to each segment, which is what is needed when a
+        list of per-segment spike vectors (or amplitudes) is used.
+ Returns ------- spike_indices: dict[dict]: @@ -82,12 +88,19 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): num_units = unit_ids.size spike_indices = {} + + total_spikes = 0 for segment_index, spikes in enumerate(spike_vector): indices = np.arange(spikes.size, dtype=np.int64) + if absolut_index: + indices += total_spikes + total_spikes += spikes.size unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False) list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units) + spike_indices[segment_index] = dict(zip(unit_ids, list_of_spike_indices)) + return spike_indices diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 09b46362e5..2a9edf7e73 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -127,7 +127,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolut_index=True) amplitudes_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): amplitudes_by_units[segment_index] = {} diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d468bd90ab..e7a9d7a992 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -140,7 +140,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolut_index=True) spike_locations_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): spike_locations_by_units[segment_index] = {} From 543cc8f2a67719e4ae8b5b64a198a6c7256406e4 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 19 Jun 2024 18:12:31 +0100 Subject: [PATCH 030/103] Add apptainer case to 'has_singularity()' Co-authored-by: Alessio Buccino --- src/spikeinterface/sorters/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 1e01b9c052..82480ffe0a 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,7 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 + return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 def has_docker_nvidia_installed(): From dceb08070af9954b25c99c82ed2df314ef924aa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:12:51 +0000 Subject: [PATCH 031/103] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/utils/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff 
--git a/src/spikeinterface/sorters/utils/misc.py b/src/spikeinterface/sorters/utils/misc.py index 82480ffe0a..9c8c3bba89 100644 --- a/src/spikeinterface/sorters/utils/misc.py +++ b/src/spikeinterface/sorters/utils/misc.py @@ -96,7 +96,10 @@ def has_docker(): def has_singularity(): - return _run_subprocess_silently("singularity --version").returncode == 0 or _run_subprocess_silently("apptainer --version").returncode == 0 + return ( + _run_subprocess_silently("singularity --version").returncode == 0 + or _run_subprocess_silently("apptainer --version").returncode == 0 + ) def has_docker_nvidia_installed(): From 8a7c145a8a1abd4c8d63c55eabb32910205053ab Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 13:30:06 +0100 Subject: [PATCH 032/103] Add peaks_on_probe widget and tests. --- src/spikeinterface/widgets/peaks_on_probe.py | 218 +++++++++++++ .../widgets/tests/test_peaks_on_probe.py | 304 ++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 3 files changed, 525 insertions(+) create mode 100644 src/spikeinterface/widgets/peaks_on_probe.py create mode 100644 src/spikeinterface/widgets/tests/test_peaks_on_probe.py diff --git a/src/spikeinterface/widgets/peaks_on_probe.py b/src/spikeinterface/widgets/peaks_on_probe.py new file mode 100644 index 0000000000..0d23b6c67e --- /dev/null +++ b/src/spikeinterface/widgets/peaks_on_probe.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import numpy as np + + +from .base import BaseWidget, to_attr + + +class PeaksOnProbeWidget(BaseWidget): + """ + Generate a plot of spike peaks showing their location on a plot + of the probe. Color scaling represents spike amplitude. + + The generated plot overlays the estimated position of a spike peak + (as a single point for each peak) onto a plot of the probe. The + dimensions of the plot are x axis: probe width, y axis: probe depth. + + Plots of different sets of peaks can be created on subplots, by + passing a list of peaks and corresponding peak locations. + + Parameters + ---------- + recording : Recording + A SpikeInterface recording object. + peaks : np.array | list[np.ndarray] + SpikeInterface 'peaks' array created with `detect_peaks()`, + an array of length num_peaks with entries: + (sample_index, channel_index, amplitude, segment_index) + To plot different sets of peaks in subplots, pass a list of peaks, each + with a corresponding entry in a list passed to `peak_locations`. + peak_locations : np.array | list[np.ndarray] + A SpikeInterface 'peak_locations' array created with `localize_peaks()`. + an array of length num_peaks with entries: (x, y) + To plot multiple peaks in subplots, pass a list of `peak_locations` + here with each entry having a corresponding `peaks`. + segment_index : None | int, default: None + If set, only peaks from this recording segment will be used. + time_range : None | Tuple, default: None + The time period over which to include peaks. If `None`, peaks + across the entire recording will be shown. + ylim : None | Tuple, default: None + The y-axis limits (i.e. the probe depth). If `None`, the entire + probe will be displayed. + decimate : int, default: 5 + For performance reasons, every nth peak is shown on the plot, + where n is set by decimate. To plot all peaks, set `decimate=1`. 
+ """ + + def __init__( + self, + recording, + peaks, + peak_locations, + segment_index=None, + time_range=None, + ylim=None, + decimate=5, + backend=None, + **backend_kwargs, + ): + data_plot = dict( + recording=recording, + peaks=peaks, + peak_locations=peak_locations, + segment_index=segment_index, + time_range=time_range, + ylim=ylim, + decimate=decimate, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from spikeinterface.widgets import plot_probe_map + + dp = to_attr(data_plot) + + peaks, peak_locations = self._check_and_format_inputs( + dp.peaks, + dp.peak_locations, + ) + fs = dp.recording.get_sampling_frequency() + num_plots = len(peaks) + + # Set the maximum time to the end time of the longest segment + if dp.time_range is None: + + time_range = self._get_min_and_max_times_in_recording(dp.recording) + else: + time_range = dp.time_range + + ## Create the figure and axes + if backend_kwargs["figsize"] is None: + backend_kwargs.update(dict(figsize=(12, 8))) + + self.figure, self.axes, self.ax = make_mpl_figure(num_axes=num_plots, **backend_kwargs) + self.axes = self.axes[0] + + # Plot each passed peaks / peak_locations over the probe on a separate subplot + for ax_idx, (peaks_to_plot, peak_locs_to_plot) in enumerate(zip(peaks, peak_locations)): + + ax = self.axes[ax_idx] + plot_probe_map(dp.recording, ax=ax) + + time_mask = self._get_peaks_time_mask(dp.recording, time_range, peaks_to_plot) + + if dp.segment_index is not None: + segment_mask = peaks_to_plot["segment_index"] == dp.segment_index + mask = time_mask & segment_mask + else: + mask = time_mask + + if not any(mask): + raise ValueError( + "No peaks within the time and segment mask found. Change `time_range` or `segment_index`" + ) + + # only plot every nth peak + peak_slice = slice(None, None, dp.decimate) + + # Find the amplitudes for the colormap scaling + # (intensity represents amplitude) + amps = np.abs(peaks_to_plot["amplitude"][mask][peak_slice]) + amps /= np.quantile(amps, 0.95) + cmap = plt.get_cmap("inferno")(amps) + color_kwargs = dict(alpha=0.2, s=2, c=cmap) + + # Plot the peaks over the plot, and set the y-axis limits. + ax.scatter( + peak_locs_to_plot["x"][mask][peak_slice], peak_locs_to_plot["y"][mask][peak_slice], **color_kwargs + ) + + if dp.ylim is None: + padding = 25 # arbitary padding just to give some space around highests and lowest peaks on the plot + ylim = (np.min(peak_locs_to_plot["y"]) - padding, np.max(peak_locs_to_plot["y"]) + padding) + else: + ylim = dp.ylim + + ax.set_ylim(ylim[0], ylim[1]) + + self.figure.suptitle(f"Peaks on Probe Plot") + + def _get_peaks_time_mask(self, recording, time_range, peaks_to_plot): + """ + Return a mask of `True` where the peak is within the given time range + and `False` otherwise. + + This is a little complex, as each segment can have different start / + end times. For each segment, find the time bounds relative to that + segment time and fill the `time_mask` one segment at a time. 
+ """ + time_mask = np.zeros(peaks_to_plot.size, dtype=bool) + + for seg_idx in range(recording.get_num_segments()): + + segment = recording.select_segments(seg_idx) + + t_start_sample = segment.time_to_sample_index(time_range[0]) + t_stop_sample = segment.time_to_sample_index(time_range[1]) + + seg_mask = peaks_to_plot["segment_index"] == seg_idx + + time_mask[seg_mask] = (t_start_sample < peaks_to_plot[seg_mask]["sample_index"]) & ( + peaks_to_plot[seg_mask]["sample_index"] < t_stop_sample + ) + + return time_mask + + def _get_min_and_max_times_in_recording(self, recording): + """ + Find the maximum and minimum time across all segments in the recording. + For example if the segment times are (10-100 s, 0 - 50s) the + min and max times are (0, 100) + """ + t_starts = [] + t_stops = [] + for seg_idx in range(recording.get_num_segments()): + + segment = recording.select_segments(seg_idx) + + t_starts.append(segment.sample_index_to_time(0)) + + t_stops.append(segment.sample_index_to_time(segment.get_num_samples() - 1)) + + time_range = (np.min(t_starts), np.max(t_stops)) + + return time_range + + def _check_and_format_inputs(self, peaks, peak_locations): + """ + Check that the inpust are in expected form. Corresponding peaks + and peak_locations of same size and format must be provided. + """ + types_are_list = [isinstance(peaks, list), isinstance(peak_locations, list)] + + if not all(types_are_list): + if any(types_are_list): + raise ValueError("`peaks` and `peak_locations` must either be both lists or both not lists.") + peaks = [peaks] + peak_locations = [peak_locations] + + if len(peaks) != len(peak_locations): + raise ValueError( + "If `peaks` and `peak_locations` are lists, they must contain " + "the same number of (corresponding) peaks and peak locations." + ) + + for idx, (peak, peak_loc) in enumerate(zip(peaks, peak_locations)): + if peak.size != peak_loc.size: + raise ValueError( + f"The number of peaks and peak_locations do not " + f"match for the {idx} input. For each spike peak, there " + f"must be a corresponding peak location" + ) + + return peaks, peak_locations diff --git a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py new file mode 100644 index 0000000000..9820ee5e72 --- /dev/null +++ b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py @@ -0,0 +1,304 @@ +import pytest +from spikeinterface.sortingcomponents.peak_localization import localize_peaks +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.widgets import plot_peaks_on_probe +from spikeinterface import generate_ground_truth_recording # TODO: think about imports +import numpy as np + + +class TestPeaksOnProbe: + + @pytest.fixture(scope="session") + def peak_info(self): + """ + Fixture (created only once per test run) of a small + ground truth recording with peaks and peak locations calculated. + """ + recording, _ = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0) + peaks = detect_peaks(recording) + + peak_locations = localize_peaks( + recording, + peaks, + ms_before=0.3, + ms_after=0.6, + method="center_of_mass", + ) + + return (recording, peaks, peak_locations) + + def data_from_widget(self, widget, axes_idx): + """ + Convenience function to get the data of the peaks + that are on the plot (not sure why they are in the + second 'collections'). 
+ """ + return widget.axes[axes_idx].collections[2].get_offsets().data + + def test_peaks_on_probe_main(self, peak_info): + """ + Plot all peaks, and check every peak is plot. + Check the labels are corect. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe(recording, peaks, peak_locations, decimate=1) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + assert widget.axes[0].get_ylabel() == "y ($\\mu m$)" + assert widget.axes[0].get_xlabel() == "x ($\\mu m$)" + + @pytest.mark.parametrize("segment_index", [0, 1]) + def test_segment_selection(self, peak_info, segment_index): + """ + Check that that when specifying only to plot peaks + from a sepecific segment, that only peaks + from that segment are plot. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + segment_index=segment_index, + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"][peaks["segment_index"] == segment_index] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + + def test_multiple_inputs(self, peak_info): + """ + Check that multiple inputs are correctly plot + on separate axes. Do this my creating a copy + of the peaks / peak locations with less peaks + and different locations, for good measure. + Check that these separate peaks / peak locations + are plot on different axes. + """ + recording, peaks, peak_locations = peak_info + + half_num_peaks = int(peaks.shape[0] / 2) + + peaks_change = peaks.copy()[:half_num_peaks] + locs_change = peak_locations.copy()[:half_num_peaks] + locs_change["y"] += 1 + + widget = plot_peaks_on_probe( + recording, + [peaks, peaks_change], + [peak_locations, locs_change], + decimate=1, + ) + + # Test the first entry, axis 0 + ax_0_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert np.array_equal(np.sort(peak_locations["y"]), np.sort(ax_0_y_data)) + + # Test the second entry, axis 1. + ax_1_y_data = self.data_from_widget(widget, 1)[:, 1] + + assert np.array_equal(np.sort(locs_change["y"]), np.sort(ax_1_y_data)) + + def test_times_all(self, peak_info): + """ + Check that when the times of peaks to plot is restricted, + only peaks within the given time range are plot. Set the + limits just before and after the second peak, and check only + that peak is plot. + """ + recording, peaks, peak_locations = peak_info + + peak_idx = 1 + peak_cutoff_low = peaks["sample_index"][peak_idx] - 1 + peak_cutoff_high = peaks["sample_index"][peak_idx] + 1 + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + time_range=( + peak_cutoff_low / recording.get_sampling_frequency(), + peak_cutoff_high / recording.get_sampling_frequency(), + ), + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert np.array_equal([peak_locations[peak_idx]["y"]], ax_y_data) + + def test_times_per_segment(self, peak_info): + """ + Test that the time bounds for multi-segment recordings + with different times are handled properly. The time bounds + given must respect the times for each segment. Here, we build + two segments with times 0-100s and 100-200s. We set the + time limits for peaks to plot as 50-150 i.e. all peaks + from the second half of the first segment, and the first half + of the second segment, should be plotted. 
+ + Recompute peaks here for completeness even though this does + duplicate the fixture. + """ + recording, _, _ = peak_info + + first_seg_times = np.linspace(0, 100, recording.get_num_samples(0)) + second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) + + recording.set_times(first_seg_times, segment_index=0) + recording.set_times(second_seg_times, segment_index=1) + + # After setting the peak times above, re-detect peaks and plot + # with a time range 50-150 s + peaks = detect_peaks(recording) + + peak_locations = localize_peaks( + recording, + peaks, + ms_before=0.3, + ms_after=0.6, + method="center_of_mass", + ) + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + time_range=( + 50, + 150, + ), + ) + + # Find the peaks that are expected to be plot given the time + # restriction (second half of first segment, first half of + # second segment) and check that indeed the expected locations + # are displayed. + seg_one_num_samples = recording.get_num_samples(0) + seg_two_num_samples = recording.get_num_samples(1) + + okay_peaks_one = np.logical_and( + peaks["segment_index"] == 0, peaks["sample_index"] > int(seg_one_num_samples / 2) + ) + okay_peaks_two = np.logical_and( + peaks["segment_index"] == 1, peaks["sample_index"] < int(seg_two_num_samples / 2) + ) + okay_peaks = np.logical_or(okay_peaks_one, okay_peaks_two) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert any(okay_peaks), "someting went wrong in test generation, no peaks within the set time bounds detected" + + assert np.array_equal(np.sort(ax_y_data), np.sort(peak_locations[okay_peaks]["y"])) + + def test_get_min_and_max_times_in_recording(self, peak_info): + """ + Check that the function which finds the minimum and maximum times + across all segments in the recording returns correctly. First + set times of the segments such that the earliest time is 50s and + latest 200s. Check the function returns (50, 200). + """ + recording, peaks, peak_locations = peak_info + + first_seg_times = np.linspace(50, 100, recording.get_num_samples(0)) + second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) + + recording.set_times(first_seg_times, segment_index=0) + recording.set_times(second_seg_times, segment_index=1) + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + ) + + min_max_times = widget._get_min_and_max_times_in_recording(recording) + + assert min_max_times == (50, 200) + + def test_ylim(self, peak_info): + """ + Specify some y-axis limits (which is the probe height + to show) and check that the plot is restricted to + these limits. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + ylim=(300, 600), + ) + + assert widget.axes[0].get_ylim() == (300, 600) + + def test_decimate(self, peak_info): + """ + By default, only a subset of peaks are shown for + performance reasons. In tests, decimate is set to 1 + to ensure all peaks are plot. This tests now + checks the decimate argument, to ensure peaks that are + plot are correctly decimated. 
+ """ + recording, peaks, peak_locations = peak_info + + decimate = 5 + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=decimate, + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"][::decimate] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + + def test_errors(self, peak_info): + """ + Test all validation errors are raised when data in + incorrect form is passed to the plotting function. + """ + recording, peaks, peak_locations = peak_info + + # All lists must be same length + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + [peaks, peaks], + [peak_locations], + ) + + # peaks and corresponding peak locations must be same size + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + [peaks[:-1]], + [peak_locations], + ) + + # if one is list, both must be lists + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + peaks, + [peak_locations], + ) + + # must have some peaks within the given time / segment + with pytest.raises(ValueError) as e: + plot_peaks_on_probe(recording, [peaks[:-1]], [peak_locations], time_range=(0, 0.001)) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index d6df59b0f3..6367e098ea 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -13,6 +13,7 @@ from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget +from .peaks_on_probe import PeaksOnProbeWidget from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -50,6 +51,7 @@ MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, + PeaksOnProbeWidget, PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, @@ -123,6 +125,7 @@ plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget +plot_peaks_on_probe = PeaksOnProbeWidget plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 7ab068b36bcc55d1efd6966051f54b152c4321e2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 19 Jun 2024 20:00:17 +0100 Subject: [PATCH 033/103] Remove tests. 
--- .../widgets/tests/test_peaks_on_probe.py | 304 ------------------ 1 file changed, 304 deletions(-) delete mode 100644 src/spikeinterface/widgets/tests/test_peaks_on_probe.py diff --git a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py deleted file mode 100644 index 9820ee5e72..0000000000 --- a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py +++ /dev/null @@ -1,304 +0,0 @@ -import pytest -from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.widgets import plot_peaks_on_probe -from spikeinterface import generate_ground_truth_recording # TODO: think about imports -import numpy as np - - -class TestPeaksOnProbe: - - @pytest.fixture(scope="session") - def peak_info(self): - """ - Fixture (created only once per test run) of a small - ground truth recording with peaks and peak locations calculated. - """ - recording, _ = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0) - peaks = detect_peaks(recording) - - peak_locations = localize_peaks( - recording, - peaks, - ms_before=0.3, - ms_after=0.6, - method="center_of_mass", - ) - - return (recording, peaks, peak_locations) - - def data_from_widget(self, widget, axes_idx): - """ - Convenience function to get the data of the peaks - that are on the plot (not sure why they are in the - second 'collections'). - """ - return widget.axes[axes_idx].collections[2].get_offsets().data - - def test_peaks_on_probe_main(self, peak_info): - """ - Plot all peaks, and check every peak is plot. - Check the labels are corect. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe(recording, peaks, peak_locations, decimate=1) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - assert widget.axes[0].get_ylabel() == "y ($\\mu m$)" - assert widget.axes[0].get_xlabel() == "x ($\\mu m$)" - - @pytest.mark.parametrize("segment_index", [0, 1]) - def test_segment_selection(self, peak_info, segment_index): - """ - Check that that when specifying only to plot peaks - from a sepecific segment, that only peaks - from that segment are plot. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - segment_index=segment_index, - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"][peaks["segment_index"] == segment_index] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - - def test_multiple_inputs(self, peak_info): - """ - Check that multiple inputs are correctly plot - on separate axes. Do this my creating a copy - of the peaks / peak locations with less peaks - and different locations, for good measure. - Check that these separate peaks / peak locations - are plot on different axes. 
- """ - recording, peaks, peak_locations = peak_info - - half_num_peaks = int(peaks.shape[0] / 2) - - peaks_change = peaks.copy()[:half_num_peaks] - locs_change = peak_locations.copy()[:half_num_peaks] - locs_change["y"] += 1 - - widget = plot_peaks_on_probe( - recording, - [peaks, peaks_change], - [peak_locations, locs_change], - decimate=1, - ) - - # Test the first entry, axis 0 - ax_0_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert np.array_equal(np.sort(peak_locations["y"]), np.sort(ax_0_y_data)) - - # Test the second entry, axis 1. - ax_1_y_data = self.data_from_widget(widget, 1)[:, 1] - - assert np.array_equal(np.sort(locs_change["y"]), np.sort(ax_1_y_data)) - - def test_times_all(self, peak_info): - """ - Check that when the times of peaks to plot is restricted, - only peaks within the given time range are plot. Set the - limits just before and after the second peak, and check only - that peak is plot. - """ - recording, peaks, peak_locations = peak_info - - peak_idx = 1 - peak_cutoff_low = peaks["sample_index"][peak_idx] - 1 - peak_cutoff_high = peaks["sample_index"][peak_idx] + 1 - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - time_range=( - peak_cutoff_low / recording.get_sampling_frequency(), - peak_cutoff_high / recording.get_sampling_frequency(), - ), - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert np.array_equal([peak_locations[peak_idx]["y"]], ax_y_data) - - def test_times_per_segment(self, peak_info): - """ - Test that the time bounds for multi-segment recordings - with different times are handled properly. The time bounds - given must respect the times for each segment. Here, we build - two segments with times 0-100s and 100-200s. We set the - time limits for peaks to plot as 50-150 i.e. all peaks - from the second half of the first segment, and the first half - of the second segment, should be plotted. - - Recompute peaks here for completeness even though this does - duplicate the fixture. - """ - recording, _, _ = peak_info - - first_seg_times = np.linspace(0, 100, recording.get_num_samples(0)) - second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) - - recording.set_times(first_seg_times, segment_index=0) - recording.set_times(second_seg_times, segment_index=1) - - # After setting the peak times above, re-detect peaks and plot - # with a time range 50-150 s - peaks = detect_peaks(recording) - - peak_locations = localize_peaks( - recording, - peaks, - ms_before=0.3, - ms_after=0.6, - method="center_of_mass", - ) - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - time_range=( - 50, - 150, - ), - ) - - # Find the peaks that are expected to be plot given the time - # restriction (second half of first segment, first half of - # second segment) and check that indeed the expected locations - # are displayed. 
- seg_one_num_samples = recording.get_num_samples(0) - seg_two_num_samples = recording.get_num_samples(1) - - okay_peaks_one = np.logical_and( - peaks["segment_index"] == 0, peaks["sample_index"] > int(seg_one_num_samples / 2) - ) - okay_peaks_two = np.logical_and( - peaks["segment_index"] == 1, peaks["sample_index"] < int(seg_two_num_samples / 2) - ) - okay_peaks = np.logical_or(okay_peaks_one, okay_peaks_two) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert any(okay_peaks), "someting went wrong in test generation, no peaks within the set time bounds detected" - - assert np.array_equal(np.sort(ax_y_data), np.sort(peak_locations[okay_peaks]["y"])) - - def test_get_min_and_max_times_in_recording(self, peak_info): - """ - Check that the function which finds the minimum and maximum times - across all segments in the recording returns correctly. First - set times of the segments such that the earliest time is 50s and - latest 200s. Check the function returns (50, 200). - """ - recording, peaks, peak_locations = peak_info - - first_seg_times = np.linspace(50, 100, recording.get_num_samples(0)) - second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) - - recording.set_times(first_seg_times, segment_index=0) - recording.set_times(second_seg_times, segment_index=1) - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - ) - - min_max_times = widget._get_min_and_max_times_in_recording(recording) - - assert min_max_times == (50, 200) - - def test_ylim(self, peak_info): - """ - Specify some y-axis limits (which is the probe height - to show) and check that the plot is restricted to - these limits. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - ylim=(300, 600), - ) - - assert widget.axes[0].get_ylim() == (300, 600) - - def test_decimate(self, peak_info): - """ - By default, only a subset of peaks are shown for - performance reasons. In tests, decimate is set to 1 - to ensure all peaks are plot. This tests now - checks the decimate argument, to ensure peaks that are - plot are correctly decimated. - """ - recording, peaks, peak_locations = peak_info - - decimate = 5 - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=decimate, - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"][::decimate] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - - def test_errors(self, peak_info): - """ - Test all validation errors are raised when data in - incorrect form is passed to the plotting function. 
- """ - recording, peaks, peak_locations = peak_info - - # All lists must be same length - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - [peaks, peaks], - [peak_locations], - ) - - # peaks and corresponding peak locations must be same size - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - [peaks[:-1]], - [peak_locations], - ) - - # if one is list, both must be lists - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - peaks, - [peak_locations], - ) - - # must have some peaks within the given time / segment - with pytest.raises(ValueError) as e: - plot_peaks_on_probe(recording, [peaks[:-1]], [peak_locations], time_range=(0, 0.001)) From bd626b0b4fbf4bfa5cfa68e857fb8ea997784c56 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 13:22:10 -0600 Subject: [PATCH 034/103] propagate FrameSlice behavior to frame_slice and time_slice --- src/spikeinterface/core/baserecording.py | 31 +++++++++---------- .../core/frameslicerecording.py | 2 +- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 184959512b..40d014cdb3 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -45,7 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): self.annotate(is_filtered=False) def __repr__(self): - extractor_name = self.__class__.__name__ num_segments = self.get_num_segments() @@ -182,7 +181,7 @@ def add_recording_segment(self, recording_segment): self._recording_segments.append(recording_segment) recording_segment.set_parent_extractor(self) - def get_num_samples(self, segment_index=None) -> int: + def get_num_samples(self, segment_index: int | None = None) -> int: """ Returns the number of samples for a segment. @@ -657,21 +656,21 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording - def frame_slice(self, start_frame: int, end_frame: int) -> BaseRecording: + def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording: """ Returns a new recording with sliced frames. Note that this operation is not in place. Parameters ---------- - start_frame : int - The start frame - end_frame : int - The end frame + start_frame : int, optional + The start frame, if not provided it is set to 0 + end_frame : int, optional + The end frame, it not provided it is set to the total number of samples Returns ------- BaseRecording - The object with sliced frames + A new recording object with only samples between start_frame and end_frame """ from .frameslicerecording import FrameSliceRecording @@ -679,27 +678,27 @@ def frame_slice(self, start_frame: int, end_frame: int) -> BaseRecording: sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording - def time_slice(self, start_time: float, end_time: float) -> BaseRecording: + def time_slice(self, start_time: float | None, end_time: float) -> BaseRecording: """ Returns a new recording with sliced time. Note that this operation is not in place. Parameters ---------- - start_time : float - The start time in seconds. - end_time : float - The end time in seconds. + start_time : float, optional + The start time in seconds. If not provided it is set to 0. + end_time : float, optional + The end time in seconds. If not provided it is set to the total duration. 
Returns ------- BaseRecording - The object with sliced time. + A new recording object with only samples between start_time and end_time """ assert self.get_num_segments() == 1, "Time slicing is only supported for single segment recordings." - start_frame = self.time_to_sample_index(start_time) - end_frame = self.time_to_sample_index(end_time) + start_frame = self.time_to_sample_index(start_time) if start_time else None + end_frame = self.time_to_sample_index(end_time) if end_time else None return self.frame_slice(start_frame=start_frame, end_frame=end_frame) diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 533328ad42..133cbf886c 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -30,7 +30,7 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" - parent_size = parent_recording.get_num_samples(0) + parent_size = parent_recording.get_num_samples(segment_index=0) if start_frame is None: start_frame = 0 else: From 3a6545700b6b55d254c79bac909e83539f471062 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jun 2024 12:06:03 +0200 Subject: [PATCH 035/103] Support kilosort>=4.0.12 --- src/spikeinterface/core/generate.py | 6 +++++ .../sorters/external/kilosort4.py | 26 ++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 05c1ebc7ed..251678e675 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1145,6 +1145,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) + start_frame = int(start_frame) + end_frame = int(end_frame) start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size @@ -1812,6 +1814,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else start_frame end_frame = self.num_samples if end_frame is None else end_frame + start_frame = int(start_frame) + end_frame = int(end_frame) if channel_indices is None: n_channels = self.templates.shape[2] @@ -1848,6 +1852,8 @@ def get_traces( end_traces = start_traces + template.shape[0] if start_traces >= end_frame - start_frame or end_traces <= 0: continue + start_traces = int(start_traces) + end_traces = int(end_traces) start_template = 0 end_template = template.shape[0] diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 47846f10ce..a7f40a9558 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Union +from packaging import version from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -24,11 +25,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, + "shift": None, + "scale": None, "artifact_threshold": None, "nskip": 25, "whitening_range": 32, "binning_depth": 5, "sig_interp": 20, + "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, "dminx": 32, @@ -63,11 +67,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": "Whether to perform common average reference. 
Default value: True.", "invert_sign": "Invert the sign of the data. Default value: False.", "nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.", + "shift": "Scalar shift to apply to data before all other operations. Default None.", + "scale": "Scaling factor to apply to data before all other operations. Default None.", "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", + "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.", "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.", @@ -153,6 +160,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import torch import numpy as np + if verbose: + import logging + + logging.basicConfig(level=logging.INFO) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -194,11 +206,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): data_dir = "" results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) + else: + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( + get_run_parameters(ops) + ) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) From 0507d59b811682f6e2fb1132099dd6575b43362b Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 20 Jun 2024 12:59:08 +0200 Subject: [PATCH 036/103] Update src/spikeinterface/core/sorting_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 5e3af58198..13f6b28f3c 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -47,7 +47,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra return spike_trains -def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolut_index=False): +def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolute_index : bool = False): """ Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return spike indices by segment and units. From 861857f6215b89ecb7220c1f34c8b504943b3b60 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 20 Jun 2024 08:16:55 -0400 Subject: [PATCH 037/103] add tests for select units for zarr --- .../core/tests/test_sortinganalyzer.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index d780932146..7456680b2a 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -67,9 +67,16 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): sorting_analyzer = create_sorting_analyzer( sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None ) + + sorting_analyzer.compute(["random_spikes", "templates"]) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 + # this bug requires that we have an info.json file so we calculate templates above + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) + assert len(select_units_sorting_analyer.unit_ids) == 1 + folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): shutil.rmtree(folder) @@ -97,9 +104,15 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): sorting_analyzer = create_sorting_analyzer( sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None ) + sorting_analyzer.compute(["random_spikes", "templates"]) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 + # this bug requires that we have an info.json file so we calculate templates above + select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) + assert len(select_units_sorting_analyer.unit_ids) == 1 + folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): shutil.rmtree(folder) @@ -312,7 +325,7 @@ def test_extensions_sorting(): if __name__ == "__main__": tmp_path = Path("test_SortingAnalyzer") - dataset = _get_dataset() + dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) test_SortingAnalyzer_binary_folder(tmp_path, dataset) test_SortingAnalyzer_zarr(tmp_path, dataset) From 8a2c56fa6ca401fca56aa8f41d432c06eb0c62b2 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 20 Jun 2024 15:12:14 +0200 Subject: [PATCH 038/103] Merci Zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/sorting_tools.py | 8 ++++---- 
src/spikeinterface/postprocessing/spike_amplitudes.py | 2 +- src/spikeinterface/postprocessing/spike_locations.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 13f6b28f3c..5ac3fcc822 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -61,11 +61,11 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab List of spike vectors optained with sorting.to_spike_vector(concatenated=False) unit_ids: np.array Unit ids - absolut_index: bool, default False + absolute_index: bool, default False Give spike indices absolut usefull when having a unique spike vector or relative to segment usefull with a list of spike vectors - When a unique spike vectors (or amplitudes) is used then absolut_index should be True. - When a list of spikes (or amplitudes) is used then absolut_index should be False. + When a unique spike vectors (or amplitudes) is used then absolute_index should be True. + When a list of spikes (or amplitudes) is used then absolute_index should be False. Returns ------- @@ -92,7 +92,7 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab total_spikes = 0 for segment_index, spikes in enumerate(spike_vector): indices = np.arange(spikes.size, dtype=np.int64) - if absolut_index: + if absolute_index: indices += total_spikes total_spikes += spikes.size unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 2a9edf7e73..aebfd1fd78 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -127,7 +127,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolut_index=True) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) amplitudes_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): amplitudes_by_units[segment_index] = {} diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index e7a9d7a992..a2dcd4a68a 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -140,7 +140,7 @@ def _get_data(self, outputs="numpy"): elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolut_index=True) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) spike_locations_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): spike_locations_by_units[segment_index] = {} From 16dd4c77b419b37f41ffc0795225afff753345ef Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 08:27:53 -0600 Subject: [PATCH 039/103] pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a2e1b3d3a5..58c0f66e44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 
+20,7 @@ classifiers = [ dependencies = [ - "numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pickling errors when numpy >2.0 + "numpy>=1.20, <2.0", # 1.20 np.ptp, 1.26 might be necessary for avoiding pickling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", "zarr>=2.16,<2.18", From 617649569e147f8a530d6cfd0c0637857481e367 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 13:14:37 -0600 Subject: [PATCH 040/103] improve error log to json in run_sorter --- src/spikeinterface/sorters/basesorter.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8c52626703..799444ddbd 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -262,7 +262,12 @@ def run_from_folder(cls, output_folder, raise_error, verbose): has_error = True run_time = None log["error"] = True - log["error_trace"] = traceback.format_exc() + error_log_to_display = traceback.format_exc() + trace_lines = error_log_to_display.strip().split("\n") + error_to_json = ["Traceback (most recent call last):"] + [ + f" {line}" if not line.startswith(" ") else line for line in trace_lines[1:] + ] + log["error_trace"] = error_to_json log["error"] = has_error log["run_time"] = run_time @@ -290,7 +295,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): if has_error and raise_error: raise SpikeSortingError( - f"Spike sorting error trace:\n{log['error_trace']}\n" + f"Spike sorting error trace:\n{error_log_to_display}\n" f"Spike sorting failed. You can inspect the runtime trace in {output_folder}/spikeinterface_log.json." ) From 2016de841178030f49a20df2749556479a4aa4ac Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 20 Jun 2024 21:14:52 +0100 Subject: [PATCH 041/103] Remove duplicate function from common test suite. --- .../postprocessing/tests/common_extension_tests.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index c99b2d4f3b..bb2f5aaafd 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -77,19 +77,6 @@ class instance is used for each. In this case, we have to set ) self.__class__.cache_folder = create_cache_folder - def _prepare_sorting_analyzer(self, format, sparse, extension_class): - """ - Prepare a SortingAnalyzer object with dependencies already computed - according to format (e.g. "memory", "binary_folder", "zarr") - and sparsity (e.g. True, False). - """ - sparsity_ = self.sparsity if sparse else None - - sorting_analyzer = self.get_sorting_analyzer( - self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name - ) - return sorting_analyzer - def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=None, name=""): sparse = sparsity is not None From 864d1d3237c206226b850409d4ee0bd12d65a32b Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 20 Jun 2024 21:17:35 +0100 Subject: [PATCH 042/103] try and fix access issue by blinding rerunning tests. From 7bdefe5c993678da6f0d618d2958c5c0e2c5e6d6 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 21 Jun 2024 09:41:29 +0100 Subject: [PATCH 043/103] Force tests again in the vain hope that doing nothing overnight has fixed the issue. 
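For readers of the error-log change in PATCH 040/103 above, here is a minimal, standalone sketch of the same idea: turning a raw traceback string into a list of lines that stays readable once dumped into spikeinterface_log.json. The helper name below is hypothetical and not part of the patch; the exact indentation rule in the patch may differ slightly.

import json
import traceback

def format_error_trace_for_json(error_trace: str) -> list[str]:
    # Split the multi-line traceback so each frame becomes its own JSON list entry,
    # re-adding the banner line and indenting bare lines for readability.
    trace_lines = error_trace.strip().split("\n")
    return ["Traceback (most recent call last):"] + [
        f"  {line}" if not line.startswith("  ") else line for line in trace_lines[1:]
    ]

try:
    raise ValueError("spike sorting failed")
except ValueError:
    log = {"error": True, "error_trace": format_error_trace_for_json(traceback.format_exc())}
    print(json.dumps(log, indent=4))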
From 6abb74b84bc766c801ac366a8678f6cdafb2a06c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 09:37:35 +0000 Subject: [PATCH 044/103] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 5ac3fcc822..65a65875e1 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -47,7 +47,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra return spike_trains -def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolute_index : bool = False): +def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, absolute_index: bool = False): """ Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return spike indices by segment and units. @@ -66,7 +66,7 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab or relative to segment usefull with a list of spike vectors When a unique spike vectors (or amplitudes) is used then absolute_index should be True. When a list of spikes (or amplitudes) is used then absolute_index should be False. - + Returns ------- spike_indices: dict[dict]: @@ -88,7 +88,7 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab num_units = unit_ids.size spike_indices = {} - + total_spikes = 0 for segment_index, spikes in enumerate(spike_vector): indices = np.arange(spikes.size, dtype=np.int64) @@ -100,7 +100,6 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab spike_indices[segment_index] = dict(zip(unit_ids, list_of_spike_indices)) - return spike_indices From 34e0e32dd9210245199b2326126f1ae626b41ee2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 11:48:50 +0200 Subject: [PATCH 045/103] Check start_frame/end_frame in BaseRecording.get_traces() rather than individual segment.get_traces() --- src/spikeinterface/core/baserecording.py | 2 ++ src/spikeinterface/core/frameslicerecording.py | 4 ---- src/spikeinterface/core/segmentutils.py | 5 ----- src/spikeinterface/extractors/cbin_ibl.py | 4 ---- src/spikeinterface/extractors/iblextractors.py | 4 ---- src/spikeinterface/extractors/nwbextractors.py | 5 ----- .../preprocessing/average_across_direction.py | 5 ----- src/spikeinterface/preprocessing/decimate.py | 7 ------- .../preprocessing/deepinterpolation/deepinterpolation.py | 8 -------- .../preprocessing/directional_derivative.py | 5 ----- src/spikeinterface/preprocessing/phase_shift.py | 4 ---- src/spikeinterface/preprocessing/remove_artifacts.py | 5 ----- src/spikeinterface/preprocessing/resample.py | 5 ----- src/spikeinterface/preprocessing/silence_periods.py | 6 ------ src/spikeinterface/preprocessing/zero_channel_pad.py | 9 --------- 15 files changed, 2 insertions(+), 76 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 184959512b..f16707f31c 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -331,6 +331,8 @@ def get_traces( segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = 
self._recording_segments[segment_index] + start_frame = int(start_frame) if start_frame is not None else 0 + end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 533328ad42..7831dd61a1 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -86,10 +86,6 @@ def get_num_samples(self): return self.end_frame - self.start_frame def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() parent_start = self.start_frame + start_frame parent_end = self.start_frame + end_frame traces = self._parent_recording_segment.get_traces( diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 959b7f8c43..b23b7202c6 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -163,11 +163,6 @@ def get_num_samples(self): return self.total_length def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # # Ensures that we won't request invalid segment indices if (start_frame >= self.get_num_samples()) or (end_frame <= start_frame): # Return (0 * num_channels) array of correct dtype diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index e5ff8ed371..a6da19408f 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -134,10 +134,6 @@ def get_num_samples(self): return self._cbuffer.shape[0] def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() if channel_indices is None: channel_indices = slice(None) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 6e3ee59cad..2444314aec 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -269,10 +269,6 @@ def get_num_samples(self): return self._file_streamer.ns def get_traces(self, start_frame: int, end_frame: int, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() if channel_indices is None: channel_indices = slice(None) traces = self._file_streamer.read(nsel=slice(start_frame, end_frame), volts=False) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 1f413ae2b0..ccb2ff4370 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -932,11 +932,6 @@ def get_num_samples(self): return self._num_samples def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - electrical_series_data = self.electrical_series_data if electrical_series_data.ndim == 1: traces = electrical_series_data[start_frame:end_frame][:, np.newaxis] diff --git 
a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 71051f07ab..53f0d54147 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -116,11 +116,6 @@ def get_num_samples(self): return self.parent_recording_segment.get_num_samples() def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - parent_traces = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 8c4970c4e4..aa5c600182 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -123,13 +123,6 @@ def get_num_samples(self): return int(np.ceil((parent_n_samp - self._decimation_offset) / self._decimation_factor)) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - end_frame = min(end_frame, self.get_num_samples()) - start_frame = min(start_frame, self.get_num_samples()) - # Account for offset and end when querying parent traces parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 80b212deda..90dbdba6da 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -148,14 +148,6 @@ def __init__( def get_traces(self, start_frame, end_frame, channel_indices): from .generators import SpikeInterfaceRecordingSegmentGenerator - n_frames = self.parent_recording_segment.get_num_samples() - - if start_frame == None: - start_frame = 0 - - if end_frame == None: - end_frame = n_frames - # for frames that lack full training data (i.e. 
pre and post frames including omissinos), # just return uninterpolated if start_frame < self.pre_frame + self.pre_post_omission: diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index 5e77cc8ae6..f8aeac05fc 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -103,11 +103,6 @@ def __init__( self.unique_pos_other_dims, self.column_inds = np.unique(geom_other_dims, axis=0, return_inverse=True) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - parent_traces = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index ca93d58364..5d483b3ce2 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -84,10 +84,6 @@ def __init__(self, parent_recording_segment, sample_shifts, margin, dtype, tmp_d self.tmp_dtype = tmp_dtype def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() if channel_indices is None: channel_indices = slice(None) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 3c0f766737..d2aef6ba3a 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -263,11 +263,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) traces = traces.copy() - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - mask = (self.triggers >= start_frame) & (self.triggers < end_frame) triggers = self.triggers[mask] - start_frame labels = self.labels[mask] diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 4843df5444..f8324817d4 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -115,11 +115,6 @@ def get_num_samples(self): return int(self._parent_segment.get_num_samples() / self._parent_rate * self.sampling_frequency) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # get parent traces with margin parent_start_frame, parent_end_frame = [ int((frame / self.sampling_frequency) * self._parent_rate) for frame in [start_frame, end_frame] diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 74d370b3a9..3758d29554 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -111,12 +111,6 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) traces = traces.copy() - num_channels = traces.shape[1] - - if start_frame is None: - start_frame = 0 - if end_frame is None: - 
end_frame = self.get_num_samples() if len(self.periods) > 0: new_interval = np.array([start_frame, end_frame]) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index cf4ba6a4a2..0b2ff9449f 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -75,11 +75,6 @@ def __init__( super().__init__(parent_recording_segment=recording_segment) def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - # This contains the padded elements by default and we add the original traces if necessary trace_size = end_frame - start_frame if isinstance(channel_indices, (np.ndarray, list)): @@ -200,10 +195,6 @@ def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, c self.channel_mapping = channel_mapping def get_traces(self, start_frame, end_frame, channel_indices): - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() traces = np.zeros((end_frame - start_frame, self.num_channels)) traces[:, self.channel_mapping] = self.parent_recording_segment.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=self.channel_mapping From 750f4cb124f53a2dc7d8cf73bf82160d86a93595 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 21 Jun 2024 11:54:56 +0200 Subject: [PATCH 046/103] A few more --- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/generate.py | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f16707f31c..16f246f280 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -331,7 +331,7 @@ def get_traces( segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] - start_frame = int(start_frame) if start_frame is not None else 0 + start_frame = int(max(0, start_frame)) if start_frame is not None else 0 end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 251678e675..70f3f120c8 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1143,11 +1143,6 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - start_frame = 0 if start_frame is None else max(start_frame, 0) - end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) - start_frame = int(start_frame) - end_frame = int(end_frame) - start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame @@ -1812,11 +1807,6 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - start_frame = 0 if start_frame is None else start_frame - end_frame = self.num_samples if end_frame is None else end_frame - start_frame = int(start_frame) - end_frame = int(end_frame) - if channel_indices 
is None:
             n_channels = self.templates.shape[2]
         elif isinstance(channel_indices, slice):

From 04fef83c22ccdf8ddbd01dec648c93d9e72a02e5 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 21 Jun 2024 11:59:13 +0200
Subject: [PATCH 047/103] Fix deepinterpolation

---
 .../preprocessing/deepinterpolation/deepinterpolation.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py
index 90dbdba6da..31ebb90831 100644
--- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py
+++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py
@@ -148,6 +148,8 @@ def __init__(
     def get_traces(self, start_frame, end_frame, channel_indices):
         from .generators import SpikeInterfaceRecordingSegmentGenerator

+        n_frames = self.parent_recording_segment.get_num_samples()
+
         # for frames that lack full training data (i.e. pre and post frames including omissinos),
         # just return uninterpolated
         if start_frame < self.pre_frame + self.pre_post_omission:

From c1c9f1f1d0d9967e3c08cf3e92aa520d9e9d289b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 21 Jun 2024 12:03:42 +0200
Subject: [PATCH 048/103] Update src/spikeinterface/core/sorting_tools.py

---
 src/spikeinterface/core/sorting_tools.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 65a65875e1..6045442466 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -62,10 +62,9 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab
     unit_ids: np.array
         Unit ids
     absolute_index: bool, default False
-        Give spike indices absolut usefull when having a unique spike vector
-        or relative to segment usefull with a list of spike vectors
-        When a unique spike vectors (or amplitudes) is used then absolute_index should be True.
-        When a list of spikes (or amplitudes) is used then absolute_index should be False.
+        It True, return absolute spike indices, else spike indices are relative to the segment.
+        When a unique spike vector is used, then absolute_index should be True.
+        When a list of spikes per segment is used, then absolute_index should be False.

     Returns
     -------

From 391db33a9aad3c49718a415da6208c6d92add7d4 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 21 Jun 2024 12:04:13 +0200
Subject: [PATCH 049/103] Update src/spikeinterface/core/sorting_tools.py

---
 src/spikeinterface/core/sorting_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 6045442466..02f4529a98 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -62,7 +62,7 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array, ab
     unit_ids: np.array
         Unit ids
     absolute_index: bool, default False
-        It True, return absolute spike indices, else spike indices are relative to the segment.
+        If True, return absolute spike indices. If False, spike indices are relative to the segment.
         When a unique spike vector is used, then absolute_index should be True.
         When a list of spikes per segment is used, then absolute_index should be False.
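A minimal usage sketch of spike_vector_to_indices with the absolute_index flag discussed in the two patches above. It assumes the post-patch spikeinterface.core API; generate_sorting is only used here to build a toy two-segment sorting for illustration.

from spikeinterface.core import generate_sorting
from spikeinterface.core.sorting_tools import spike_vector_to_indices

# Toy sorting with two segments so both indexing modes can be compared.
sorting = generate_sorting(num_units=3, durations=[5.0, 5.0], seed=0)
spike_vector = sorting.to_spike_vector(concatenated=False)

# Relative mode: indices restart at 0 in every segment, matching per-segment arrays
# such as a list of amplitude arrays (one array per segment).
relative_indices = spike_vector_to_indices(spike_vector, sorting.unit_ids, absolute_index=False)

# Absolute mode: indices point into the single concatenated spike vector, matching
# one flat array of amplitudes or locations covering all segments.
absolute_indices = spike_vector_to_indices(spike_vector, sorting.unit_ids, absolute_index=True)

unit0 = sorting.unit_ids[0]
print(relative_indices[1][unit0][:5])   # restarts near 0 in segment 1
print(absolute_indices[1][unit0][:5])   # offset by the number of spikes in segment 0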
From a5926907a783410eb483ec68fef72b2a81bd06f1 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:01:54 -0400 Subject: [PATCH 050/103] fix the probe handling tutorial --- .../core/plot_3_handle_probe_info.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index 157efb683f..50905871bc 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -47,16 +47,24 @@ plot_probe(recording_2_shanks.get_probe()) ############################################################################### -# Now let's check what we have loaded. The `group_mode='by_shank'` automatically +# Now let's check what we have loaded. The :code:`group_mode='by_shank'` automatically # sets the 'group' property depending on the shank id. -# We can use this information to split the recording into two sub-recordings: +# We can use this information to split the recording into two sub-recordings. +# We can acccess this information either as a dict with :code:`outputs='dict'` (default) +# or as a list of recordings with :code:`outputs='list'`. print(recording_2_shanks) -print(recording_2_shanks.get_property("group")) +print(f'/nGroup Property: {recording_2_shanks.get_property("group")}/n') -rec0, rec1 = recording_2_shanks.split_by(property="group") -print(rec0) -print(rec1) +# Here we split as a dict +sub_recording_dict = recording_2_shanks.split_by(property="group", outputs='dict') +print(sub_recording_dict, '/n') + +# Then we can pull out the individual sub-recordings +sub_rec0 = sub_recording_dict[0] +sub_rec1 = sub_recording_dict[1] +print(sub_rec0, '/n') +print(sub_rec1) ############################################################################### # Note that some formats (MEArec, SpikeGLX) automatically handle the probe From 8f543534c529c90241bef81238cde9c4e2badd12 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:39:08 -0400 Subject: [PATCH 051/103] fix slashes --- examples/tutorials/core/plot_3_handle_probe_info.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index 50905871bc..99a92b8a0a 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -54,16 +54,15 @@ # or as a list of recordings with :code:`outputs='list'`. 
print(recording_2_shanks) -print(f'/nGroup Property: {recording_2_shanks.get_property("group")}/n') +print(f'\nGroup Property: {recording_2_shanks.get_property("group")}\n') # Here we split as a dict sub_recording_dict = recording_2_shanks.split_by(property="group", outputs='dict') -print(sub_recording_dict, '/n') # Then we can pull out the individual sub-recordings sub_rec0 = sub_recording_dict[0] sub_rec1 = sub_recording_dict[1] -print(sub_rec0, '/n') +print(sub_rec0, '\n') print(sub_rec1) ############################################################################### From a71eff4109c945bcfb1f384c39f6c4127629fa2f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 21 Jun 2024 18:25:22 -0400 Subject: [PATCH 052/103] add typing --- src/spikeinterface/core/base.py | 6 +++--- src/spikeinterface/core/baserecordingsnippets.py | 8 ++++---- src/spikeinterface/core/basesorting.py | 10 ++++------ src/spikeinterface/core/binaryfolder.py | 2 +- src/spikeinterface/core/binaryrecordingextractor.py | 2 +- src/spikeinterface/core/core_tools.py | 13 ++++++++++--- src/spikeinterface/core/frameslicerecording.py | 2 +- src/spikeinterface/core/generate.py | 2 +- src/spikeinterface/core/numpyextractors.py | 2 +- src/spikeinterface/core/recording_tools.py | 12 +++++++++++- src/spikeinterface/core/sortinganalyzer.py | 2 +- src/spikeinterface/core/waveform_tools.py | 2 +- src/spikeinterface/sorters/basesorter.py | 6 +++--- 13 files changed, 42 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6fbc5ac289..4922707b35 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -550,7 +550,7 @@ def check_serializability(self, type): return False return self._serializability[type] - def check_if_memory_serializable(self): + def check_if_memory_serializable(self) -> bool: """ Check if the object is serializable to memory with pickle, including nested objects. @@ -561,7 +561,7 @@ def check_if_memory_serializable(self): """ return self.check_serializability("memory") - def check_if_json_serializable(self): + def check_if_json_serializable(self) -> bool: """ Check if the object is json serializable, including nested objects. @@ -574,7 +574,7 @@ def check_if_json_serializable(self): # is this needed ??? I think no. return self.check_serializability("json") - def check_if_pickle_serializable(self): + def check_if_pickle_serializable(self) -> bool: # is this needed ??? I think no. 
return self.check_serializability("pickle") diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 2a9f075954..428472bf93 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -48,7 +48,7 @@ def get_num_channels(self): def get_dtype(self): return self._dtype - def has_scaleable_traces(self): + def has_scaleable_traces(self) -> bool: if self.get_property("gain_to_uV") is None or self.get_property("offset_to_uV") is None: return False else: @@ -62,10 +62,10 @@ def has_scaled(self): ) return self.has_scaleable_traces() - def has_probe(self): + def has_probe(self) -> bool: return "contact_vector" in self.get_property_keys() - def has_channel_location(self): + def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() def is_filtered(self): @@ -366,7 +366,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): locations = np.asarray(locations)[channel_indices] return select_axes(locations, axes) - def has_3d_locations(self): + def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 def clear_channel_locations(self, channel_ids=None): diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 7214d2780e..fd68df9dda 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np @@ -73,7 +73,7 @@ def unit_ids(self): def sampling_frequency(self): return self._sampling_frequency - def get_unit_ids(self) -> List: + def get_unit_ids(self) -> list: return self._main_ids def get_num_units(self) -> int: @@ -121,7 +121,7 @@ def get_total_samples(self) -> int: s += self.get_num_samples(segment_index) return s - def get_total_duration(self): + def get_total_duration(self) -> float: """Returns the total duration in s of the associated recording. Returns @@ -219,7 +219,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict): def has_recording(self): return self._recording is not None - def has_time_vector(self, segment_index=None): + def has_time_vector(self, segment_index=None) -> bool: """ Check if the segment of the registered recording has a time vector. 
""" @@ -515,8 +515,6 @@ def precompute_spike_trains(self, from_spike_vector=None): """ Pre-computes and caches all spike trains for this sorting - - Parameters ---------- from_spike_vector : None | bool, default: None diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index ec9bdfcc5e..546ac85f93 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -53,7 +53,7 @@ def __init__(self, folder_path): assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json" self._bin_kwargs["num_channels"] = self._bin_kwargs["num_chan"] - def is_binary_compatible(self): + def is_binary_compatible(self) -> bool: return True def get_binary_description(self): diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 5d72532704..8fb9a78f2a 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -147,7 +147,7 @@ def write_recording(recording, file_paths, dtype=None, **job_kwargs): """ write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) - def is_binary_compatible(self): + def is_binary_compatible(self) -> bool: return True def get_binary_description(self): diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 664eac169f..f3d8b3df7f 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -168,9 +168,14 @@ def make_shared_array(shape, dtype): return arr, shm -def is_dict_extractor(d): +def is_dict_extractor(d: dict) -> bool: """ - Check if a dict describe an extractor. + Check if a dict describes an extractor. + + Returns + ------- + is_extractor : bool + Whether the dict describes an extractor """ if not isinstance(d, dict): return False @@ -283,6 +288,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool: Returns ------- relative_possible: bool + Whether the given input can be made relative to the relative_folder """ path_list = _get_paths_list(input_dict) relative_folder = Path(relative_folder).resolve().absolute() @@ -513,7 +519,8 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): def retrieve_importing_provenance(a_class): """ - Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version. + Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), + the top-level module, and the module version. 
Parameters ---------- diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 133cbf886c..5c91d3cae1 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -82,7 +82,7 @@ def __init__(self, parent_recording_segment, start_frame, end_frame): self.start_frame = start_frame self.end_frame = end_frame - def get_num_samples(self): + def get_num_samples(self) -> int: return self.end_frame - self.start_frame def get_traces(self, start_frame, end_frame, channel_indices): diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 251678e675..370f5b42c6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1134,7 +1134,7 @@ def __init__( elif self.strategy == "on_the_fly": pass - def get_num_samples(self): + def get_num_samples(self) -> int: return self.num_samples def get_traces( diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 62cd2fe2cf..0ba1c05417 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -110,7 +110,7 @@ def __init__(self, traces, sampling_frequency, t_start): self._traces = traces self.num_samples = traces.shape[0] - def get_num_samples(self): + def get_num_samples(self) -> int: return self.num_samples def get_traces(self, start_frame, end_frame, channel_indices): diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 8b1b293543..b4c07e77c9 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -862,7 +862,17 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), def check_probe_do_not_overlap(probes): """ When several probes this check that that they do not overlap in space - and so channel positions can be safly concatenated. + and so channel positions can be safely concatenated. + + Raises + ------ + Exception : + If probes are overlapping + + Returns + ------- + None : None + If the check is successful """ for i in range(len(probes)): probe_i = probes[i] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0094012013..e439ddf1ed 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1229,7 +1229,7 @@ def get_computable_extensions(self): """ return get_available_analyzer_extensions() - def get_default_extension_params(self, extension_name: str): + def get_default_extension_params(self, extension_name: str) -> dict: """ Get the default params for an extension. 
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index acc368b2e5..befc49d034 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None return waveforms_by_units -def has_exceeding_spikes(recording, sorting): +def has_exceeding_spikes(recording, sorting) -> bool: """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8c52626703..a9513f9f5a 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -343,7 +343,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_ return sorting @classmethod - def check_compiled(cls): + def check_compiled(cls) -> bool: """ Checks if the sorter is running inside an image with matlab-compiled version @@ -370,7 +370,7 @@ def check_compiled(cls): return True @classmethod - def use_gpu(cls, params): + def use_gpu(cls, params) -> bool: return cls.gpu_capability != "not-supported" ############################################# @@ -436,7 +436,7 @@ def get_job_kwargs(params, verbose): return job_kwargs -def is_log_ok(output_folder): +def is_log_ok(output_folder) -> bool: # log is OK when run_time is not None if (output_folder / "spikeinterface_log.json").is_file(): with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: From 048fa788a58580673f030af756be274b859bc65e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 15:54:52 +0200 Subject: [PATCH 053/103] Add checks for start/end_frames --- src/spikeinterface/core/baserecording.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 16f246f280..f69d6d25f8 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -331,8 +331,12 @@ def get_traces( segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] - start_frame = int(max(0, start_frame)) if start_frame is not None else 0 + start_frame = int(start_frame) if start_frame is not None else 0 end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() + if start_frame < 0: + raise ValueError("start_frame cannot be negative") + if start_frame > end_frame: + raise ValueError("start_frame cannot be greater than end_frame") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] From a3a27cf217ab767b1531dd7082ad3bfed190ba3d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 16:00:34 +0200 Subject: [PATCH 054/103] Fix failing binaryrecordingextractor test --- src/spikeinterface/core/tests/test_binaryrecordingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 61af8f322d..8ea99e3d04 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ 
-33,7 +33,7 @@ def test_BinaryRecordingExtractor(create_cache_folder): def test_round_trip(tmp_path): num_channels = 10 - num_samples = 50 + num_samples = 500 traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")] sampling_frequency = 30_000.0 recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency) From a18ea3b9d9dc5f5edc2b35c459daa7d573646141 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 22 Jun 2024 17:03:19 +0200 Subject: [PATCH 055/103] Remove check on start_frame > end_frame --- src/spikeinterface/core/baserecording.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f69d6d25f8..4d924e9003 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -335,8 +335,6 @@ def get_traces( end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() if start_frame < 0: raise ValueError("start_frame cannot be negative") - if start_frame > end_frame: - raise ValueError("start_frame cannot be greater than end_frame") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] From bc933715b2d2895ee21e3de4064e4c06c30e2f27 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Sat, 22 Jun 2024 14:48:21 -0400 Subject: [PATCH 056/103] Alessio's fix Co-authored-by: Alessio Buccino --- examples/tutorials/core/plot_3_handle_probe_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index 99a92b8a0a..deff58ebb7 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -50,7 +50,7 @@ # Now let's check what we have loaded. The :code:`group_mode='by_shank'` automatically # sets the 'group' property depending on the shank id. # We can use this information to split the recording into two sub-recordings. -# We can acccess this information either as a dict with :code:`outputs='dict'` (default) +# We can access this information either as a dict with :code:`outputs='dict'` (default) # or as a list of recordings with :code:`outputs='list'`. 
print(recording_2_shanks)

From fd6369b8b6493258b4a36d88567f392125e0bf58 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Sat, 22 Jun 2024 17:59:28 -0600
Subject: [PATCH 057/103] use names as channel ids in plexon

---
 src/spikeinterface/extractors/neoextractors/plexon2.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py
index fe24ba6f46..256c112e6f 100644
--- a/src/spikeinterface/extractors/neoextractors/plexon2.py
+++ b/src/spikeinterface/extractors/neoextractors/plexon2.py
@@ -30,7 +30,12 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):
     def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False):
         neo_kwargs = self.map_to_neo_kwargs(file_path)
         NeoBaseRecordingExtractor.__init__(
-            self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs
+            self,
+            stream_id=stream_id,
+            stream_name=stream_name,
+            all_annotations=all_annotations,
+            use_names_as_ids=True,
+            **neo_kwargs,
         )
         self._kwargs.update({"file_path": str(file_path)})

From 53b3ec9bdf49c1f93aa6e03a1ecdc55a1ba00a8f Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Sat, 22 Jun 2024 18:27:13 -0600
Subject: [PATCH 058/103] add docstring and propagate argument to signature

---
 .../extractors/neoextractors/plexon2.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py
index 256c112e6f..941158def1 100644
--- a/src/spikeinterface/extractors/neoextractors/plexon2.py
+++ b/src/spikeinterface/extractors/neoextractors/plexon2.py
@@ -19,6 +19,13 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream id you want to load.
     stream_name : str, default: None
         If there are several streams, specify the stream name you want to load.
+    use_names_as_ids : bool, default: True
+        If True, the names of the signals are used as channel ids. If False, the channel ids are a combination of the
+        source id and the channel index.
+
+        Example for wideband signals:
+            names: ["WB01", "WB02", "WB03", "WB04"]
+            ids: ["source3.1" , "source3.2", "source3.3", "source3.4"]
     all_annotations : bool, default: False
         Load exhaustively all annotations from neo.
""" @@ -27,14 +34,14 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): NeoRawIOClass = "Plexon2RawIO" name = "plexon2" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, - use_names_as_ids=True, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) self._kwargs.update({"file_path": str(file_path)}) From f533225c7236b3c3133e657286b64a5435abc032 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 11:38:50 +0200 Subject: [PATCH 059/103] Add unsigned offset to sinaps extractor, typing, docs, and cleaning --- .../extractors/sinapsrecordingextractor.py | 114 -------- .../extractors/sinapsrecordingextractors.py | 258 ++++++++++++++++++ .../extractors/sinapsrecordingh5extractor.py | 149 ---------- 3 files changed, 258 insertions(+), 263 deletions(-) delete mode 100644 src/spikeinterface/extractors/sinapsrecordingextractor.py create mode 100644 src/spikeinterface/extractors/sinapsrecordingextractors.py delete mode 100644 src/spikeinterface/extractors/sinapsrecordingh5extractor.py diff --git a/src/spikeinterface/extractors/sinapsrecordingextractor.py b/src/spikeinterface/extractors/sinapsrecordingextractor.py deleted file mode 100644 index 1f35407c33..0000000000 --- a/src/spikeinterface/extractors/sinapsrecordingextractor.py +++ /dev/null @@ -1,114 +0,0 @@ -from pathlib import Path -import numpy as np - -from probeinterface import get_probe - -from ..core import BinaryRecordingExtractor, ChannelSliceRecording -from ..core.core_tools import define_function_from_class - - -class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): - extractor_name = "SinapsResearchPlatform" - mode = "file" - name = "sinaps_research_platform" - - def __init__(self, file_path, stream_name="filt"): - from ..preprocessing import UnsignedToSignedRecording - - file_path = Path(file_path) - meta_file = file_path.parent / f"metadata_{file_path.stem}.txt" - meta = parse_sinaps_meta(meta_file) - - num_aux_channels = meta["nbHWAux"] + meta["numberUserAUX"] - num_total_channels = 2 * meta["nbElectrodes"] + num_aux_channels - num_electrodes = meta["nbElectrodes"] - sampling_frequency = meta["samplingFreq"] - - probe_type = meta["probeType"] - # channel_locations = meta["electrodePhysicalPosition"] # will be depricated soon by Sam, switching to probeinterface - num_shanks = meta["nbShanks"] - num_electrodes_per_shank = meta["nbElectrodesShank"] - num_bits = int(np.log2(meta["nbADCLevels"])) - - # channel_groups = [] - # for i in range(num_shanks): - # channel_groups.extend([i] * num_electrodes_per_shank) - - gain_ephys = meta["voltageConverter"] - gain_aux = meta["voltageAUXConverter"] - - recording = BinaryRecordingExtractor( - file_path, sampling_frequency, dtype="uint16", num_channels=num_total_channels - ) - recording = UnsignedToSignedRecording(recording, bit_depth=num_bits) - - if stream_name == "raw": - channel_slice = recording.channel_ids[:num_electrodes] - renamed_channels = np.arange(num_electrodes) - # locations = channel_locations - # groups = channel_groups - gain = gain_ephys - elif stream_name == "filt": - channel_slice = recording.channel_ids[num_electrodes : 2 * num_electrodes] - renamed_channels = np.arange(num_electrodes) - # locations = channel_locations - 
# groups = channel_groups - gain = gain_ephys - elif stream_name == "aux": - channel_slice = recording.channel_ids[2 * num_electrodes :] - hw_chans = meta["hwAUXChannelName"][1:-1].split(",") - user_chans = meta["userAuxName"][1:-1].split(",") - renamed_channels = hw_chans + user_chans - # locations = None - # groups = None - gain = gain_aux - else: - raise ValueError("stream_name must be 'raw', 'filt', or 'aux'") - - ChannelSliceRecording.__init__(self, recording, channel_ids=channel_slice, renamed_channel_ids=renamed_channels) - # if locations is not None: - # self.set_channel_locations(locations) - # if groups is not None: - # self.set_channel_groups(groups) - - self.set_channel_gains(gain) - self.set_channel_offsets(0) - - if (stream_name == "filt") | (stream_name == "raw"): - if probe_type == "p1024s1NHP": - probe = get_probe(manufacturer="sinaps", probe_name="SiNAPS-p1024s1NHP") - # now wire the probe - channel_indices = np.arange(1024) - probe.set_device_channel_indices(channel_indices) - self.set_probe(probe, in_place=True) - else: - raise ValueError(f"Unknown probe type: {probe_type}") - - self._kwargs = {"file_path": str(file_path.absolute())} - - -read_sinaps_research_platform = define_function_from_class( - source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" -) - - -def parse_sinaps_meta(meta_file): - meta_dict = {} - with open(meta_file) as f: - lines = f.readlines() - for l in lines: - if "**" in l or "=" not in l: - continue - else: - key, val = l.split("=") - val = val.replace("\n", "") - try: - val = int(val) - except: - pass - try: - val = eval(val) - except: - pass - meta_dict[key] = val - return meta_dict diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py new file mode 100644 index 0000000000..df86085dc7 --- /dev/null +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import warnings +from pathlib import Path +import numpy as np + +from probeinterface import get_probe + +from ..core import BaseRecording, BaseRecordingSegment, BinaryRecordingExtractor, ChannelSliceRecording +from ..core.core_tools import define_function_from_class + + +class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): + """ + Recording extractor for the SiNAPS research platform system saved in binary format. + + Parameters + ---------- + file_path : str | Path + Path to the SiNAPS .bin file. + stream_name : "filt" | "raw" | "aux", default: "filt" + The stream name to extract. + "filt" extracts the filtered data, "raw" extracts the raw data, and "aux" extracts the auxiliary data. 
+ """ + + extractor_name = "SinapsResearchPlatform" + mode = "file" + name = "sinaps_research_platform" + + def __init__(self, file_path: str | Path, stream_name: str = "filt"): + from ..preprocessing import UnsignedToSignedRecording + + file_path = Path(file_path) + meta_file = file_path.parent / f"metadata_{file_path.stem}.txt" + meta = parse_sinaps_meta(meta_file) + + num_aux_channels = meta["nbHWAux"] + meta["numberUserAUX"] + num_total_channels = 2 * meta["nbElectrodes"] + num_aux_channels + num_electrodes = meta["nbElectrodes"] + sampling_frequency = meta["samplingFreq"] + + probe_type = meta["probeType"] + num_bits = int(np.log2(meta["nbADCLevels"])) + + gain_ephys = meta["voltageConverter"] + gain_aux = meta["voltageAUXConverter"] + + recording = BinaryRecordingExtractor( + file_path, sampling_frequency, dtype="uint16", num_channels=num_total_channels + ) + recording = UnsignedToSignedRecording(recording, bit_depth=num_bits) + + if stream_name == "raw": + channel_slice = recording.channel_ids[:num_electrodes] + renamed_channels = np.arange(num_electrodes) + gain = gain_ephys + elif stream_name == "filt": + channel_slice = recording.channel_ids[num_electrodes : 2 * num_electrodes] + renamed_channels = np.arange(num_electrodes) + gain = gain_ephys + elif stream_name == "aux": + channel_slice = recording.channel_ids[2 * num_electrodes :] + hw_chans = meta["hwAUXChannelName"][1:-1].split(",") + user_chans = meta["userAuxName"][1:-1].split(",") + renamed_channels = hw_chans + user_chans + gain = gain_aux + else: + raise ValueError("stream_name must be 'raw', 'filt', or 'aux'") + + ChannelSliceRecording.__init__(self, recording, channel_ids=channel_slice, renamed_channel_ids=renamed_channels) + + self.set_channel_gains(gain) + self.set_channel_offsets(0) + num_channels = self.get_num_channels() + + if (stream_name == "filt") | (stream_name == "raw"): + probe = get_sinaps_probe(probe_type, num_channels) + if probe is not None: + self.set_probe(probe, in_place=True) + + self._kwargs = {"file_path": str(file_path.absolute()), "stream_name": stream_name} + + +class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): + """ + Recording extractor for the SiNAPS research platform system saved in HDF5 format. + + Parameters + ---------- + file_path : str | Path + Path to the SiNAPS .h5 file. 
+ """ + + extractor_name = "SinapsResearchPlatformH5" + mode = "file" + name = "sinaps_research_platform_h5" + + def __init__(self, file_path: str | Path): + self._file_path = file_path + + sinaps_info = parse_sinapse_h5(self._file_path) + self._rf = sinaps_info["filehandle"] + + BaseRecording.__init__( + self, + sampling_frequency=sinaps_info["sampling_frequency"], + channel_ids=sinaps_info["channel_ids"], + dtype=sinaps_info["dtype"], + ) + + self.extra_requirements.append("h5py") + + recording_segment = SiNAPSH5RecordingSegment( + self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"] + ) + self.add_recording_segment(recording_segment) + + # set gain + self.set_channel_gains(sinaps_info["gain"]) + self.set_channel_offsets(sinaps_info["offset"]) + self.num_bits = sinaps_info["num_bits"] + num_channels = self.get_num_channels() + + # set probe + probe = get_sinaps_probe(sinaps_info["probe_type"], num_channels) + if probe is not None: + self.set_probe(probe, in_place=True) + + self._kwargs = {"file_path": str(Path(file_path).absolute())} + + def __del__(self): + self._rf.close() + + +class SiNAPSH5RecordingSegment(BaseRecordingSegment): + def __init__(self, rf, num_frames, sampling_frequency, num_bits): + BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) + self._rf = rf + self._num_samples = int(num_frames) + self._num_bits = num_bits + self._stream = self._rf.require_group("RealTimeProcessedData") + + def get_num_samples(self): + return self._num_samples + + def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): + if isinstance(channel_indices, slice): + traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T + else: + # channel_indices is np.ndarray + if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): + # get around h5py constraint that it does not allow datasets + # to be indexed out of order + sorted_channel_indices = np.sort(channel_indices) + resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) + recordings = self._stream.get("FilteredData")[sorted_channel_indices, start_frame:end_frame].T + traces = recordings[:, resorted_indices] + else: + traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T + # convert uint16 to int16 here to simplify extractor + if traces.dtype == "uint16": + dtype_signed = "int16" + # upcast to int with double itemsize + signed_dtype = "int32" + offset = 2 ** (self._num_bits - 1) + traces = traces.astype(signed_dtype, copy=False) - offset + traces = traces.astype(dtype_signed, copy=False) + return traces + + +read_sinaps_research_platform = define_function_from_class( + source_class=SinapsResearchPlatformRecordingExtractor, name="read_sinaps_research_platform" +) + +read_sinaps_research_platform_h5 = define_function_from_class( + source_class=SinapsResearchPlatformH5RecordingExtractor, name="read_sinaps_research_platform_h5" +) + + +############################################## +# HELPER FUNCTIONS +############################################## + + +def get_sinaps_probe(probe_type, num_channels): + try: + probe = get_probe(manufacturer="sinaps", probe_name=f"SiNAPS-{probe_type}") + # now wire the probe + channel_indices = np.arange(num_channels) + probe.set_device_channel_indices(channel_indices) + return probe + except: + warnings.warn(f"Could not load probe information for {probe_type}") + return None + + +def parse_sinaps_meta(meta_file): + 
meta_dict = {} + with open(meta_file) as f: + lines = f.readlines() + for l in lines: + if "**" in l or "=" not in l: + continue + else: + key, val = l.split("=") + val = val.replace("\n", "") + try: + val = int(val) + except: + pass + try: + val = eval(val) + except: + pass + meta_dict[key] = val + return meta_dict + + +def parse_sinapse_h5(filename): + """Open an SiNAPS hdf5 file, read and return the recording info.""" + + import h5py + + rf = h5py.File(filename, "r") + + stream = rf.require_group("RealTimeProcessedData") + data = stream.get("FilteredData") + dtype = data.dtype + + parameters = rf.require_group("Parameters") + gain = parameters.get("VoltageConverter")[0] + offset = 0 + + nRecCh, nFrames = data.shape + + samplingRate = parameters.get("SamplingFrequency")[0] + + probe_type = str( + rf.require_group("Advanced Recording Parameters").require_group("Probe").get("probeType").asstr()[...] + ) + num_bits = int( + np.log2(rf.require_group("Advanced Recording Parameters").require_group("DAQ").get("nbADCLevels")[0]) + ) + + sinaps_info = { + "filehandle": rf, + "num_frames": nFrames, + "sampling_frequency": samplingRate, + "num_channels": nRecCh, + "channel_ids": np.arange(nRecCh), + "gain": gain, + "offset": offset, + "dtype": dtype, + "probe_type": probe_type, + "num_bits": num_bits, + } + + return sinaps_info diff --git a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py b/src/spikeinterface/extractors/sinapsrecordingh5extractor.py deleted file mode 100644 index dbfcb239fa..0000000000 --- a/src/spikeinterface/extractors/sinapsrecordingh5extractor.py +++ /dev/null @@ -1,149 +0,0 @@ -from pathlib import Path -import numpy as np - -from probeinterface import get_probe - -from ..core.core_tools import define_function_from_class -from ..core import BaseRecording, BaseRecordingSegment -from ..preprocessing import UnsignedToSignedRecording - - -class SinapsResearchPlatformH5RecordingExtractor_Unsigned(BaseRecording): - extractor_name = "SinapsResearchPlatformH5" - mode = "file" - name = "sinaps_research_platform_h5" - - def __init__(self, file_path): - - try: - import h5py - - self.installed = True - except ImportError: - self.installed = False - - assert self.installed, self.installation_mesg - self._file_path = file_path - - sinaps_info = openSiNAPSFile(self._file_path) - self._rf = sinaps_info["filehandle"] - - BaseRecording.__init__( - self, - sampling_frequency=sinaps_info["sampling_frequency"], - channel_ids=sinaps_info["channel_ids"], - dtype=sinaps_info["dtype"], - ) - - self.extra_requirements.append("h5py") - - recording_segment = SiNAPSRecordingSegment( - self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"] - ) - self.add_recording_segment(recording_segment) - - # set gain - self.set_channel_gains(sinaps_info["gain"]) - self.set_channel_offsets(sinaps_info["offset"]) - self.num_bits = sinaps_info["num_bits"] - - # set probe - if sinaps_info["probe_type"] == "p1024s1NHP": - probe = get_probe(manufacturer="sinaps", probe_name="SiNAPS-p1024s1NHP") - probe.set_device_channel_indices(np.arange(1024)) - self.set_probe(probe, in_place=True) - else: - raise ValueError(f"Unknown probe type: {sinaps_info['probe_type']}") - - # set other properties - - self._kwargs = {"file_path": str(Path(file_path).absolute())} - - def __del__(self): - self._rf.close() - - -class SiNAPSRecordingSegment(BaseRecordingSegment): - def __init__(self, rf, num_frames, sampling_frequency): - BaseRecordingSegment.__init__(self, 
sampling_frequency=sampling_frequency) - self._rf = rf - self._num_samples = int(num_frames) - self._stream = self._rf.require_group("RealTimeProcessedData") - - def get_num_samples(self): - return self._num_samples - - def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): - if isinstance(channel_indices, slice): - traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T - else: - # channel_indices is np.ndarray - if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): - # get around h5py constraint that it does not allow datasets - # to be indexed out of order - sorted_channel_indices = np.sort(channel_indices) - resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) - recordings = self._stream.get("FilteredData")[sorted_channel_indices, start_frame:end_frame].T - traces = recordings[:, resorted_indices] - else: - traces = self._stream.get("FilteredData")[channel_indices, start_frame:end_frame].T - return traces - - -class SinapsResearchPlatformH5RecordingExtractor(UnsignedToSignedRecording): - extractor_name = "SinapsResearchPlatformH5" - mode = "file" - name = "sinaps_research_platform_h5" - - def __init__(self, file_path): - recording = SinapsResearchPlatformH5RecordingExtractor_Unsigned(file_path) - UnsignedToSignedRecording.__init__(self, recording, bit_depth=recording.num_bits) - - self._kwargs = {"file_path": str(Path(file_path).absolute())} - - -read_sinaps_research_platform_h5 = define_function_from_class( - source_class=SinapsResearchPlatformH5RecordingExtractor, name="read_sinaps_research_platform_h5" -) - - -def openSiNAPSFile(filename): - """Open an SiNAPS hdf5 file, read and return the recording info.""" - - import h5py - - rf = h5py.File(filename, "r") - - stream = rf.require_group("RealTimeProcessedData") - data = stream.get("FilteredData") - dtype = data.dtype - - parameters = rf.require_group("Parameters") - gain = parameters.get("VoltageConverter")[0] - offset = 0 - - nRecCh, nFrames = data.shape - - samplingRate = parameters.get("SamplingFrequency")[0] - - probe_type = str( - rf.require_group("Advanced Recording Parameters").require_group("Probe").get("probeType").asstr()[...] 
- ) - num_bits = int( - np.log2(rf.require_group("Advanced Recording Parameters").require_group("DAQ").get("nbADCLevels")[0]) - ) - - sinaps_info = { - "filehandle": rf, - "num_frames": nFrames, - "sampling_frequency": samplingRate, - "num_channels": nRecCh, - "channel_ids": np.arange(nRecCh), - "gain": gain, - "offset": offset, - "dtype": dtype, - "probe_type": probe_type, - "num_bits": num_bits, - } - - return sinaps_info From c044633d0376f62c217e1b4f8bdf715082c4c6e4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 11:41:21 +0200 Subject: [PATCH 060/103] fix extractorlist --- src/spikeinterface/extractors/extractorlist.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index b226a2d838..8948aad606 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -45,8 +45,12 @@ from .herdingspikesextractors import HerdingspikesSortingExtractor, read_herdingspikes from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort -from .sinapsrecordingextractor import SinapsResearchPlatformRecordingExtractor, read_sinaps_research_platform -from .sinapsrecordingh5extractor import SinapsResearchPlatformH5RecordingExtractor, read_sinaps_research_platform_h5 +from .sinapsrecordingextractors import ( + SinapsResearchPlatformRecordingExtractor, + SinapsResearchPlatformH5RecordingExtractor, + read_sinaps_research_platform, + read_sinaps_research_platform_h5, +) # sorting in relation with simulator from .shybridextractors import ( From 465be4286fc65a9f6fe70703168ff973e7f7581d Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 24 Jun 2024 11:02:49 +0100 Subject: [PATCH 061/103] Fix a missing argument (num_bits) --- src/spikeinterface/extractors/sinapsrecordingextractors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index df86085dc7..1642fcf351 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -111,7 +111,8 @@ def __init__(self, file_path: str | Path): self.extra_requirements.append("h5py") recording_segment = SiNAPSH5RecordingSegment( - self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"] + self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"], + num_bits = sinaps_info["num_bits"] ) self.add_recording_segment(recording_segment) From bfb42c5cc074793e6a297931a44c4bf772505931 Mon Sep 17 00:00:00 2001 From: Nina Kudryashova Date: Mon, 24 Jun 2024 11:09:33 +0100 Subject: [PATCH 062/103] Run black --- src/spikeinterface/extractors/sinapsrecordingextractors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 1642fcf351..522f639760 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -111,8 +111,10 @@ def __init__(self, file_path: str | Path): self.extra_requirements.append("h5py") recording_segment = SiNAPSH5RecordingSegment( - 
self._rf, sinaps_info["num_frames"], sampling_frequency=sinaps_info["sampling_frequency"], - num_bits = sinaps_info["num_bits"] + self._rf, + sinaps_info["num_frames"], + sampling_frequency=sinaps_info["sampling_frequency"], + num_bits=sinaps_info["num_bits"], ) self.add_recording_segment(recording_segment) From d1e2d866610ccfc2932f6a3e45eb21fd2d353b6b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 12:49:39 +0200 Subject: [PATCH 063/103] Update src/spikeinterface/core/baserecording.py --- src/spikeinterface/core/baserecording.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 4d924e9003..71eecc15bc 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -332,7 +332,8 @@ def get_traces( channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) rs = self._recording_segments[segment_index] start_frame = int(start_frame) if start_frame is not None else 0 - end_frame = int(min(end_frame, rs.get_num_samples())) if end_frame is not None else rs.get_num_samples() + num_samples = rs.get_num_samples() + end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples if start_frame < 0: raise ValueError("start_frame cannot be negative") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) From 5c78a567c54288716f75956ba5885789ca8b36f6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:04:28 +0100 Subject: [PATCH 064/103] Update doc and docstrings for template_metric units --- doc/modules/core.rst | 2 +- doc/modules/postprocessing.rst | 12 ++++++-- .../postprocessing/template_metrics.py | 29 +++++++++++++------ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 239e42bc3c..73d2217453 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -21,7 +21,7 @@ All classes support: * data on-demand (lazy loading) * multiple segments, where each segment is a contiguous piece of data (recording, sorting, events). - +.. _core-recording: Recording --------- diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 6465e4af48..00ddefe979 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -304,11 +304,19 @@ By default, the following metrics are computed: * "peak_to_valley": duration between negative and positive peaks * "halfwidth": duration in s at 50% of the amplitude * "peak_to_trough_ratio": ratio between negative and positive peaks -* "recovery_slope": speed in V/s to recover from the negative peak to 0 -* "repolarization_slope": speed in V/s to repolarize from the positive peak to 0 +* "recovery_slope": speed to recover from the negative peak to 0 +* "repolarization_slope": speed to repolarize from the positive peak to 0 * "num_positive_peaks": the number of positive peaks * "num_negative_peaks": the number of negative peaks +The units of the results depend on the input. Voltages are based on the units of the +template, usually :math:`\mu V` (this depends on the :code:`return_scaled` +parameter, read more here: :ref:`core-recording`). Distances are based on the unit of the +underlying recording's probe's :code:`channel_locations`, usually :math:`\mu m`. +Times are always in seconds. E.g. 
if the templates are in units of :math:`mV` and channel +locations in :math:`\mu m` then: :code:`repolarization_slope` is in :math:`mV / s`; +:code:`peak_to_trough_ratio` is in :math:`\mu m` and the :code:`halfwidth` is in :math:`s`. + Optionally, the following multi-channel metrics can be computed by setting: :code:`include_multi_channel_metrics=True` diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index d7179ffefa..35d954389a 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -410,7 +410,8 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non After reaching it's maximum polarization, the neuron potential will recover. The repolarization slope is defined as the dV/dT of the action potential - between trough and baseline. + between trough and baseline. The returned slope is in units of (unit of template) + per second. Parameters ---------- @@ -454,12 +455,10 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes until it peaks. The recovery slope is the slope of the action potential after the peak, returning to the baseline - in dV/dT. The slope is computed within a user-defined window after + in dV/dT. The returned slope is in units of (unit of template) + per second. The slope is computed within a user-defined window after the peak. - Takes a numpy array of waveforms and returns an array with - recovery slopes per waveform. - Parameters ---------- template_single: numpy.ndarray @@ -619,7 +618,7 @@ def fit_velocity(peak_times, channel_dist): def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the velocity above the max channel of the template. + Compute the velocity above the max channel of the template in units (unit of channel locations) per second, usually um/s. Parameters ---------- @@ -697,7 +696,7 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the velocity below the max channel of the template. + Compute the velocity below the max channel of the template in units (unit of channel locations) per second, usually um/s. Parameters ---------- @@ -775,7 +774,8 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ - Compute the exponential decay of the template amplitude over distance. + Compute the exponential decay of the template amplitude over distance. The returned value + is in the same units as `channel_locations`, usually um. 
Parameters ---------- @@ -788,6 +788,11 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - min_r2_exp_decay: the minimum r2 to accept the exp decay fit + + Returns + ------- + exp_decay_value : float + The exponential decay of the template amplitude """ from scipy.optimize import curve_fit from sklearn.metrics import r2_score @@ -853,7 +858,8 @@ def exp_decay(x, decay, amp0, offset): def get_spread(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the spread of the template amplitude over distance. + Compute the spread of the template amplitude over distance. The returned value + is in the same units as `channel_locations`, usually um. Parameters ---------- @@ -867,6 +873,11 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread - column_range: the range in um in the x-direction to consider channels for velocity + + Returns + ------- + spread : float + Spread of the template amplitude """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" depth_direction = kwargs["depth_direction"] From a96fdbbe7f3f4ab0410118c4ca7e2b2464f03ff2 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:41:25 +0100 Subject: [PATCH 065/103] Respond to reviews --- doc/modules/core.rst | 7 +++-- doc/modules/postprocessing.rst | 27 ++++++++++--------- .../postprocessing/template_metrics.py | 19 ++++++------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 73d2217453..5c0713fa21 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -162,7 +162,7 @@ Internally, any sorting object can construct 2 internal caches: 2. a unique numpy.array with structured dtype aka "spikes vector". This is useful for processing by small chunks of time, like for extracting amplitudes from a recording. - +.. _core-sorting-analyzer: SortingAnalyzer --------------- @@ -179,9 +179,8 @@ to perform further analysis, such as calculating :code:`waveforms` and :code:`te Importantly, the :py:class:`~spikeinterface.core.SortingAnalyzer` handles the *sparsity* and the physical *scaling*. Sparsity defines the channels on which waveforms and templates are calculated using, for example, a physical distance from the channel with the largest peak amplitude (see the :ref:`Sparsity` section). Scaling, set by -the :code:`return_scaled` argument, says whether the data has been converted from integer values to physical units such as -Voltage (see the end of the :ref:`Recording` section). - +the :code:`return_scaled` argument, determines whether the data is converted from integer values to :math:`\mu V` or not. +By default, it is converted and all traces have units of :math:`\mu V`. Now we will create a :code:`SortingAnalyzer` called :code:`sorting_analyzer`. diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 00ddefe979..9aad8568f4 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -301,29 +301,30 @@ template_metrics This extension computes commonly used waveform/template metrics. 
By default, the following metrics are computed: -* "peak_to_valley": duration between negative and positive peaks -* "halfwidth": duration in s at 50% of the amplitude +* "peak_to_valley": duration in :math:`s` between negative and positive peaks +* "halfwidth": duration in :math:`s` at 50% of the amplitude * "peak_to_trough_ratio": ratio between negative and positive peaks * "recovery_slope": speed to recover from the negative peak to 0 * "repolarization_slope": speed to repolarize from the positive peak to 0 * "num_positive_peaks": the number of positive peaks * "num_negative_peaks": the number of negative peaks -The units of the results depend on the input. Voltages are based on the units of the -template, usually :math:`\mu V` (this depends on the :code:`return_scaled` -parameter, read more here: :ref:`core-recording`). Distances are based on the unit of the -underlying recording's probe's :code:`channel_locations`, usually :math:`\mu m`. -Times are always in seconds. E.g. if the templates are in units of :math:`mV` and channel -locations in :math:`\mu m` then: :code:`repolarization_slope` is in :math:`mV / s`; -:code:`peak_to_trough_ratio` is in :math:`\mu m` and the :code:`halfwidth` is in :math:`s`. +The units of :code:`recovery_slope` and :code:`repolarization_slope` depend on the +input. Voltages are based on the units of the template. By default this is :math:`\mu V` +but can be the raw output from the recording device (this depends on the +:code:`return_scaled` parameter, read more here: :ref:`core-sorting-analyzer`). +Distances are in :math:`\mu m` and times are in seconds. So, for example, if the +templates are in units of :math:`\mu V` then: :code:`repolarization_slope` is in +:math:`mV / s`; :code:`peak_to_trough_ratio` is in :math:`\mu m` and the +:code:`halfwidth` is in :math:`s`. Optionally, the following multi-channel metrics can be computed by setting: :code:`include_multi_channel_metrics=True` -* "velocity_above": the velocity above the max channel of the template -* "velocity_below": the velocity below the max channel of the template -* "exp_decay": the exponential decay of the template amplitude over distance -* "spread": the spread of the template amplitude over distance +* "velocity_above": the velocity in :math:`\mu m/s` above the max channel of the template +* "velocity_below": the velocity in :math:`\mu m/s` below the max channel of the template +* "exp_decay": the exponential decay in :math:`\mu m` of the template amplitude over distance +* "spread": the spread in :math:`\mu m` of the template amplitude over distance .. figure:: ../images/1d_waveform_features.png diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 35d954389a..fdc4ef4719 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -411,7 +411,9 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non After reaching it's maximum polarization, the neuron potential will recover. The repolarization slope is defined as the dV/dT of the action potential between trough and baseline. The returned slope is in units of (unit of template) - per second. + per second. By default traces are scaled to units of uV, controlled + by `sorting_analyzer.return_scaled`. In this case this function returns the slope + in uV/s. 
Parameters ---------- @@ -456,8 +458,9 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa the neuron hyperpolarizes until it peaks. The recovery slope is the slope of the action potential after the peak, returning to the baseline in dV/dT. The returned slope is in units of (unit of template) - per second. The slope is computed within a user-defined window after - the peak. + per second. By default traces are scaled to units of uV, controlled + by `sorting_analyzer.return_scaled`. In this case this function returns the slope + in uV/s. The slope is computed within a user-defined window after the peak. Parameters ---------- @@ -618,7 +621,7 @@ def fit_velocity(peak_times, channel_dist): def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the velocity above the max channel of the template in units (unit of channel locations) per second, usually um/s. + Compute the velocity above the max channel of the template in units um/s. Parameters ---------- @@ -696,7 +699,7 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the velocity below the max channel of the template in units (unit of channel locations) per second, usually um/s. + Compute the velocity below the max channel of the template in units um/s. Parameters ---------- @@ -774,8 +777,7 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ - Compute the exponential decay of the template amplitude over distance. The returned value - is in the same units as `channel_locations`, usually um. + Compute the exponential decay of the template amplitude over distance in units um/s. Parameters ---------- @@ -858,8 +860,7 @@ def exp_decay(x, decay, amp0, offset): def get_spread(template, channel_locations, sampling_frequency, **kwargs): """ - Compute the spread of the template amplitude over distance. The returned value - is in the same units as `channel_locations`, usually um. + Compute the spread of the template amplitude over distance in units um/s. Parameters ---------- From 227d91d02bc5c1b80dff106c3606632d8604c05c Mon Sep 17 00:00:00 2001 From: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:08:55 +0100 Subject: [PATCH 066/103] Update doc/modules/core.rst Co-authored-by: Alessio Buccino --- doc/modules/core.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 5c0713fa21..f8f410018b 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -180,7 +180,7 @@ Importantly, the :py:class:`~spikeinterface.core.SortingAnalyzer` handles the *s Sparsity defines the channels on which waveforms and templates are calculated using, for example, a physical distance from the channel with the largest peak amplitude (see the :ref:`Sparsity` section). Scaling, set by the :code:`return_scaled` argument, determines whether the data is converted from integer values to :math:`\mu V` or not. -By default, it is converted and all traces have units of :math:`\mu V`. +By default, :code:`return_scaled` is true and all processed data voltage values are in :math:`\mu V` (e.g., waveforms, templates, spike amplitudes, etc.). Now we will create a :code:`SortingAnalyzer` called :code:`sorting_analyzer`. 
From 99565a321c888a5bfe69cacc2466d0b651c37fb0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 17:10:47 +0200 Subject: [PATCH 067/103] Update src/spikeinterface/core/baserecording.py --- src/spikeinterface/core/baserecording.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 71eecc15bc..ede59c3e66 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -334,8 +334,6 @@ def get_traces( start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples - if start_frame < 0: - raise ValueError("start_frame cannot be negative") traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices) if order is not None: assert order in ["C", "F"] From fc3e6331eb3284e592808e238ab6954cb394154f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jun 2024 17:49:48 +0200 Subject: [PATCH 068/103] Add plot_drift_map --- src/spikeinterface/widgets/driftmap.py | 143 ++++++++++++++++++++++ src/spikeinterface/widgets/motion.py | 84 +++++-------- src/spikeinterface/widgets/widget_list.py | 3 + 3 files changed, 179 insertions(+), 51 deletions(-) create mode 100644 src/spikeinterface/widgets/driftmap.py diff --git a/src/spikeinterface/widgets/driftmap.py b/src/spikeinterface/widgets/driftmap.py new file mode 100644 index 0000000000..60e8df2972 --- /dev/null +++ b/src/spikeinterface/widgets/driftmap.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import numpy as np + +from .base import BaseWidget, to_attr + + +class DriftMapWidget(BaseWidget): + """ + Plot the a drift map from a motion info dictionary. + + Parameters + ---------- + peaks : np.array + The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index") + peak_locations : np.array + The peak locations, with dtype ("x", "y") or ("x", "y", "z") + direction : "x" or "y", default: "y" + The direction to display + segment_index : int, default: None + The segment index to display. + recording : RecordingExtractor, default: None + The recording extractor object (only used to get "real" times) + segment_index : int, default: 0 + The segment index to display. 
+ sampling_frequency : float, default: None + The sampling frequency (needed if recording is None) + depth_lim : tuple or None, default: None + The min and max depth to display, if None (min and max of the recording) + color_amplitude : bool, default: True + If True, the color of the scatter points is the amplitude of the peaks + scatter_decimate : int, default: None + If > 1, the scatter points are decimated + cmap : str, default: "inferno" + The colormap to use for the amplitude + clim : tuple or None, default: None + The min and max amplitude to display, if None (min and max of the amplitudes) + alpha : float, default: 1 + The alpha of the scatter points + """ + + def __init__( + self, + peaks, + peak_locations, + direction="y", + recording=None, + sampling_frequency=None, + segment_index=None, + depth_lim=None, + color_amplitude=True, + scatter_decimate=None, + cmap="inferno", + clim=None, + alpha=1, + backend=None, + **backend_kwargs, + ): + if segment_index is None: + assert ( + len(np.unique(peaks["segment_index"])) == 1 + ), "segment_index must be specified if there is only one segment in the peaks array" + assert recording or sampling_frequency, "recording or sampling_frequency must be specified" + if recording is not None: + sampling_frequency = recording.sampling_frequency + times = recording.get_times(segment_index=segment_index) + else: + times = None + + plot_data = dict( + peaks=peaks, + peak_locations=peak_locations, + direction=direction, + times=times, + sampling_frequency=sampling_frequency, + segment_index=segment_index, + depth_lim=depth_lim, + color_amplitude=color_amplitude, + scatter_decimate=scatter_decimate, + cmap=cmap, + clim=clim, + alpha=alpha, + recording=recording, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from matplotlib.colors import Normalize + + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + + if dp.times is None: + # temporal_bins_plot = dp.temporal_bins + x = dp.peaks["sample_index"] / dp.sampling_frequency + else: + # use real times and adjust temporal bins with t_start + # temporal_bins_plot = dp.temporal_bins + dp.times[0] + x = dp.times[dp.peaks["sample_index"]] + + y = dp.peak_locations[dp.direction] + if dp.scatter_decimate is not None: + x = x[:: dp.scatter_decimate] + y = y[:: dp.scatter_decimate] + y2 = y2[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peaks["amplitude"] + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.colormaps[dp.cmap] + if dp.clim is None: + amps = amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) + c = cmap(norm_function(amps)) + color_kwargs = dict( + color=None, + c=c, + alpha=dp.alpha, + ) + else: + color_kwargs = dict(color="k", c=None, alpha=dp.alpha) + + self.ax.scatter(x, y, s=1, **color_kwargs) + if dp.depth_lim is not None: + self.ax.set_ylim(*dp.depth_lim) + self.ax.set_title("Peak depth") + self.ax.set_xlabel("Times [s]") + self.ax.set_ylabel("Depth 
[$\\mu$m]") diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index fc0c91423d..7d733523df 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -3,6 +3,7 @@ import numpy as np from .base import BaseWidget, to_attr +from .driftmap import DriftMapWidget class MotionWidget(BaseWidget): @@ -107,7 +108,7 @@ class MotionInfoWidget(BaseWidget): Parameters ---------- motion_info : dict - The motion info return by correct_motion() or load back with load_motion_info() + The motion info returned by correct_motion() or loaded back with load_motion_info() segment_index : int, default: None The segment index to display. recording : RecordingExtractor, default: None @@ -153,7 +154,9 @@ def __init__( if len(motion.displacement) == 1: segment_index = 0 else: - raise ValueError("plot motion : teh Motion object is multi segment you must provide segmentindex=XX") + raise ValueError( + "plot drift map : the Motion object is multi-segment you must provide segment_index=XX" + ) times = recording.get_times() if recording is not None else None @@ -214,14 +217,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.sharex(ax0) ax1.sharey(ax0) - if dp.times is None: - # temporal_bins_plot = dp.temporal_bins - x = dp.peaks["sample_index"] / dp.sampling_frequency - else: - # use real times and adjust temporal bins with t_start - # temporal_bins_plot = dp.temporal_bins + dp.times[0] - x = dp.times[dp.peaks["sample_index"]] - corrected_location = correct_motion_on_peaks( dp.peaks, dp.peak_locations, @@ -229,47 +224,34 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.recording, ) - y = dp.peak_locations[motion.direction] - y2 = corrected_location[motion.direction] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = dp.peaks["amplitude"] - amps_abs = np.abs(amps) - q_95 = np.quantile(amps_abs, 0.95) - if dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.colormaps[dp.amplitude_cmap] - if dp.amplitude_clim is None: - amps = amps_abs - amps /= q_95 - c = cmap(amps) - else: - norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) - c = cmap(norm_function(amps)) - color_kwargs = dict( - color=None, - c=c, - alpha=dp.amplitude_alpha, - ) - else: - color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) - - ax0.scatter(x, y, s=1, **color_kwargs) - if dp.depth_lim is not None: - ax0.set_ylim(*dp.depth_lim) - ax0.set_title("Peak depth") - ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [$\\mu$m]") - - ax1.scatter(x, y2, s=1, **color_kwargs) - ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [$\\mu$m]") - ax1.set_title("Corrected peak depth") + commpon_drift_map_kwargs = dict( + direction=dp.motion.direction, + recording=dp.recording, + segment_index=dp.segment_index, + depth_lim=dp.depth_lim, + color_amplitude=dp.color_amplitude, + scatter_decimate=dp.scatter_decimate, + cmap=dp.amplitude_cmap, + clim=dp.amplitude_clim, + alpha=dp.amplitude_alpha, + backend="matplotlib", + ) + + drift_map = DriftMapWidget( + dp.peaks, + dp.peak_locations, + ax=ax0, + immediate_plot=True, + **commpon_drift_map_kwargs, + ) + + drift_map_corrected = DriftMapWidget( + dp.peaks, + corrected_location, + ax=ax1, + immediate_plot=True, + **commpon_drift_map_kwargs, + ) ax2.plot(temporal_bins_s, 
displacement, alpha=0.2, color="black") ax2.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6367e098ea..8d4accaa7e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,6 +9,7 @@ from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget +from .driftmap import DriftMapWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget @@ -44,6 +45,7 @@ ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, + DriftMapWidget, ISIDistributionWidget, MotionWidget, MotionInfoWidget, @@ -118,6 +120,7 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_drift_map = DriftMapWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget plot_motion_info = MotionInfoWidget From baf1287215e41b020dc97b5d6428dbdc5446ef76 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 25 Jun 2024 11:20:06 +0100 Subject: [PATCH 069/103] Fix docstrings for extractors module --- doc/api.rst | 2 +- doc/modules/extractors.rst | 2 +- src/spikeinterface/extractors/cbin_ibl.py | 4 +++- .../extractors/herdingspikesextractors.py | 2 +- src/spikeinterface/extractors/iblextractors.py | 2 +- .../extractors/neoextractors/alphaomega.py | 5 +++++ .../extractors/neoextractors/biocam.py | 1 - .../extractors/neoextractors/blackrock.py | 3 ++- .../extractors/neoextractors/ced.py | 2 -- .../extractors/neoextractors/intan.py | 2 ++ .../extractors/neoextractors/maxwell.py | 2 ++ .../extractors/neoextractors/neuralynx.py | 11 ++++++----- .../extractors/neoextractors/plexon2.py | 2 +- .../extractors/neoextractors/spikegadgets.py | 2 +- .../extractors/neoextractors/tdt.py | 2 ++ src/spikeinterface/extractors/toy_example.py | 10 ++++++++-- src/spikeinterface/preprocessing/filter.py | 17 ++++++++--------- 17 files changed, 44 insertions(+), 27 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index a7476cd62f..c5c9ebe4dd 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,7 +117,7 @@ Non-NEO-based .. autofunction:: read_bids .. autofunction:: read_cbin_ibl .. autofunction:: read_combinato - .. autofunction:: read_ibl_streaming_recording + .. autofunction:: read_ibl_recording .. autofunction:: read_hdsort .. autofunction:: read_herdingspikes .. 
autofunction:: read_kilosort diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index 2d0e047672..ba08e45aca 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -125,7 +125,7 @@ For raw recording formats, we currently support: * **Biocam HDF5** :py:func:`~spikeinterface.extractors.read_biocam()` * **CED** :py:func:`~spikeinterface.extractors.read_ced()` * **EDF** :py:func:`~spikeinterface.extractors.read_edf()` -* **IBL streaming** :py:func:`~spikeinterface.extractors.read_ibl_streaming_recording()` +* **IBL streaming** :py:func:`~spikeinterface.extractors.read_ibl_recording()` * **Intan** :py:func:`~spikeinterface.extractors.read_intan()` * **MaxWell** :py:func:`~spikeinterface.extractors.read_maxwell()` * **MCS H5** :py:func:`~spikeinterface.extractors.read_mcsh5()` diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index a6da19408f..1687acb073 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -27,9 +27,11 @@ class CompressedBinaryIblExtractor(BaseRecording): load_sync_channel : bool, default: False Load or not the last channel (sync). If not then the probe is loaded. - stream_name : str, default: "ap". + stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". + cbin_file : str or None, default None + The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 139d51d62e..87f7dd74c4 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -20,7 +20,7 @@ class HerdingspikesSortingExtractor(BaseSorting): Parameters ---------- - folder_path : str or Path + file_path : str or Path Path to the ALF folder. load_unit_info : bool, default: True Whether to load the unit info from the file. diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 2444314aec..27bb95854f 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -41,7 +41,7 @@ class IblRecordingExtractor(BaseRecording): stream_name : str The name of the stream to load for the session. These can be retrieved from calling `StreamingIblExtractor.get_stream_names(session="")`. - load_sync_channels : bool, default: false + load_sync_channel : bool, default: false Load or not the last channel (sync). If not then the probe is loaded. cache_folder : str or None, default: None diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 5c8e58d3a5..239928f66d 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -50,6 +50,11 @@ def map_to_neo_kwargs(cls, folder_path, lsx_files=None): class AlphaOmegaEventExtractor(NeoBaseEventExtractor): """ Class for reading events from AlphaOmega MPX file format + + Parameters + ---------- + folder_path : str or Path-like + The folder path to the AlphaOmega events. 
""" mode = "folder" diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 96d4dd25a6..9f23575dba 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -42,7 +42,6 @@ def __init__( electrode_width=None, stream_id=None, stream_name=None, - block_index=None, all_annotations=False, ): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 5e28c4a20d..0015fd9f67 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -26,6 +26,8 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool or None, default: None + If True, use channel names as IDs. If None, use default IDs. """ mode = "file" @@ -37,7 +39,6 @@ def __init__( file_path, stream_id=None, stream_name=None, - block_index=None, all_annotations=False, use_names_as_ids=False, ): diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 401c927fc7..e2c79478fa 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -23,8 +23,6 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream id you want to load. stream_name : str, default: None If there are several streams, specify the stream name you want to load. - block_index : int, default: None - If there are several blocks, specify the block index you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. """ diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index c37ff47807..9d4db3103c 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -27,6 +27,8 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): If True, data that violates integrity assumptions will be loaded. At the moment the only integrity check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. + use_names_as_ids : bool or None, default: None + If True, use channel names as IDs. If None, use default IDs. """ mode = "file" diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 3888b6d5a0..a66075b451 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -35,6 +35,8 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): you want to extract. (rec_name='rec0000'). install_maxwell_plugin : bool, default: False If True, install the maxwell plugin for neo. 
+    block_index : int, default: None
+        If there are several blocks (experiments), specify the block index you want to load
     """

     mode = "file"
diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py
index 25b6bb5b61..0670371ba9 100644
--- a/src/spikeinterface/extractors/neoextractors/neuralynx.py
+++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py
@@ -26,16 +26,17 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream name you want to load.
     all_annotations : bool, default: False
         Load exhaustively all annotations from neo.
-    exlude_filename : list[str], default: None
+    exclude_filename : list[str], default: None
         List of filename to exclude from the loading.
         For example, use `exclude_filename=["events.nev"]` to skip loading the event file.
     strict_gap_mode : bool, default: False
         See neo documentation.
         Detect gaps using strict mode or not.
-        * strict_gap_mode = True then a gap is consider when timstamp difference between two
-          consecutive data packets is more than one sample interval.
-        * strict_gap_mode = False then a gap has an increased tolerance. Some new systems with different clocks need this option
-          otherwise, too many gaps are detected
+        * strict_gap_mode = True then a gap is considered when the timestamp difference between
+          two consecutive data packets is more than one sample interval.
+        * strict_gap_mode = False then a gap has an increased tolerance. Some new systems
+          with different clocks need this option, otherwise too many gaps are detected.
+
         Note that here the default is False contrary to neo.

     """
diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py
index 941158def1..c7351a308b 100644
--- a/src/spikeinterface/extractors/neoextractors/plexon2.py
+++ b/src/spikeinterface/extractors/neoextractors/plexon2.py
@@ -19,7 +19,7 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):
         If there are several streams, specify the stream id you want to load.
     stream_name : str, default: None
         If there are several streams, specify the stream name you want to load.
-    use_names_as_ids:
+    use_names_as_ids : bool, default: True
         If True, the names of the signals are used as channel ids. If False, the channel ids are a combination of the source
         id and the channel index.
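A quick sketch of the option documented above (the file name is a placeholder):

    from spikeinterface.extractors import read_plexon2

    # with use_names_as_ids=True the channel ids are the recorded channel names;
    # with False they are built from the source id and the channel index
    recording = read_plexon2("session.pl2", use_names_as_ids=True)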
diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index f326c49cd1..3d57817f88 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -32,7 +32,7 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): NeoRawIOClass = "SpikeGadgetsRawIO" name = "spikegadgets" - def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 146f6a4b4c..27b456102f 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -23,6 +23,8 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + block_index : int, default: None + If there are several blocks (experiments), specify the block index you want to load """ mode = "folder" diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index 450044d07b..2f007cca88 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -57,12 +57,18 @@ def toy_example( Spike time in the recording spike_labels : np.array or list[nparray] or None, default: None Cluster label for each spike time (needs to specified both together). - # score_detection : int (between 0 and 1) - # Generate the sorting based on a subset of spikes compare with the trace generation firing_rate : float, default: 3.0 The firing rate for the units (in Hz) seed : int or None, default: None Seed for random initialization. + upsample_factor : None or int, default: None + A upsampling factor used only when templates are not provided. + num_columns : int, default: 1 + Number of columns in probe. + average_peak_amplitude : float, default: -100 + Average peak amplitude of generated templates + contact_spacing_um : float, default: 40.0 + Spacing between probe contacts. Returns ------- diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 6a1733c57c..d18227ca83 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -10,15 +10,14 @@ _common_filter_docs = """**filter_kwargs : dict Certain keyword arguments for `scipy.signal` filters: - filter_order : order - The order of the filter - filter_mode : "sos" | "ba", default: "sos" - Filter form of the filter coefficients: - - second-order sections ("sos") - - numerator/denominator : ("ba") - ftype : str, default: "butter" - Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". - """ + filter_order : order + The order of the filter + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. 
"butter", "cheby1".""" class FilterRecording(BasePreprocessor): From 5c28ecfe9f93ed5deff88ae3ad5485e37ae4b1f6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 25 Jun 2024 08:38:04 -0600 Subject: [PATCH 070/103] Add macos and windows to cache cron jobs (#3075) Add macos and windows to cron jobs for caching testing data --- .github/workflows/caches_cron_job.yml | 68 ++++++++++----------------- 1 file changed, 25 insertions(+), 43 deletions(-) diff --git a/.github/workflows/caches_cron_job.yml b/.github/workflows/caches_cron_job.yml index 20e2a55178..2454e97ad7 100644 --- a/.github/workflows/caches_cron_job.yml +++ b/.github/workflows/caches_cron_job.yml @@ -2,64 +2,35 @@ name: Create caches for gin ecephys data and virtual env on: workflow_dispatch: - push: # When someting is pushed into main this checks if caches need to re-created + push: # When something is pushed into main this checks if caches need to be re-created branches: - main schedule: - cron: "0 12 * * *" # Daily at noon UTC jobs: - - - - create-virtual-env-cache-if-missing: - name: Caching virtual env - runs-on: "ubuntu-latest" - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - name: Get current year-month - id: date - run: | - echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - name: Get current dependencies hash - id: dependencies - run: | - echo "hash=${{hashFiles('**/pyproject.toml')}}" >> $GITHUB_OUTPUT - - uses: actions/cache@v4 - id: cache-venv - with: - path: ${{ github.workspace }}/test_env - key: ${{ runner.os }}-venv-${{ steps.dependencies.outputs.hash }}-${{ steps.date.outputs.date }} - lookup-only: 'true' # Avoids downloading the data, saving behavior is not affected. - - name: Cache found? - run: echo "Cache-hit == ${{steps.cache-venv.outputs.cache-hit == 'true'}}" - - name: Create the virtual environment to be cached - if: steps.cache-venv.outputs.cache-hit != 'true' - uses: ./.github/actions/build-test-environment - - - - create-gin-data-cache-if-missing: name: Caching data env - runs-on: "ubuntu-latest" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Create the directory to store the data run: | - mkdir --parents --verbose $HOME/spikeinterface_datasets/ephy_testing_data/ - chmod -R 777 $HOME/spikeinterface_datasets - ls -l $HOME/spikeinterface_datasets + mkdir -p ~/spikeinterface_datasets/ephy_testing_data/ + ls -l ~/spikeinterface_datasets + shell: bash - name: Get current hash (SHA) of the ephy_testing_data repo id: repo_hash run: | echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + shell: bash - uses: actions/cache@v4 id: cache-datasets with: @@ -68,6 +39,7 @@ jobs: lookup-only: 'true' # Avoids downloading the data, saving behavior is not affected. - name: Cache found? 
run: echo "Cache-hit == ${{steps.cache-datasets.outputs.cache-hit == 'true'}}" + shell: bash - name: Installing datalad and git-annex if: steps.cache-datasets.outputs.cache-hit != 'true' run: | @@ -75,20 +47,29 @@ jobs: git config --global user.name "CI Almighty" python -m pip install -U pip # Official recommended way pip install datalad-installer - datalad-installer --sudo ok git-annex --method datalad/packages + if [ ${{ runner.os }} == 'Linux' ]; then + datalad-installer --sudo ok git-annex --method datalad/packages + elif [ ${{ runner.os }} == 'macOS' ]; then + datalad-installer --sudo ok git-annex --method brew + elif [ ${{ runner.os }} == 'Windows' ]; then + datalad-installer --sudo ok git-annex --method datalad/git-annex:release + fi pip install datalad git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + shell: bash - name: Download dataset if: steps.cache-datasets.outputs.cache-hit != 'true' run: | datalad install --recursive --get-data https://gin.g-node.org/NeuralEnsemble/ephy_testing_data + shell: bash - name: Move the downloaded data to the right directory if: steps.cache-datasets.outputs.cache-hit != 'true' run: | - mv --force ./ephy_testing_data $HOME/spikeinterface_datasets/ + mv ./ephy_testing_data ~/spikeinterface_datasets/ + shell: bash - name: Show size of the cache to assert data is downloaded run: | - cd $HOME + cd ~ pwd du -hs spikeinterface_datasets # Should show the size of ephy_testing_data cd spikeinterface_datasets @@ -96,3 +77,4 @@ jobs: ls -lh # Should show ephy_testing_data cd ephy_testing_data ls -lh + shell: bash From 99cc04ef882a7695c08e473fd9f98df942feb2d8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 25 Jun 2024 12:47:15 -0600 Subject: [PATCH 071/103] Add tests for windows and mac (#2937) * extend tests for windows and mac --------- Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/run_tests.sh | 7 +- .github/workflows/all-tests.yml | 129 ++++++++++++++++++ pyproject.toml | 7 +- src/spikeinterface/core/datasets.py | 56 +++++--- .../extractors/tests/common_tests.py | 5 +- .../tests/test_datalad_downloading.py | 13 +- .../extractors/tests/test_neoextractors.py | 9 +- .../tests/test_principal_component.py | 2 +- .../sorters/tests/test_container_tools.py | 5 +- 9 files changed, 197 insertions(+), 36 deletions(-) create mode 100644 .github/workflows/all-tests.yml diff --git a/.github/run_tests.sh b/.github/run_tests.sh index 04a6b5ac6b..558e0b64d3 100644 --- a/.github/run_tests.sh +++ b/.github/run_tests.sh @@ -1,8 +1,13 @@ #!/bin/bash MARKER=$1 +NOVIRTUALENV=$2 + +# Check if the second argument is provided and if it is equal to --no-virtual-env +if [ -z "$NOVIRTUALENV" ] || [ "$NOVIRTUALENV" != "--no-virtual-env" ]; then + source $GITHUB_WORKSPACE/test_env/bin/activate +fi -source $GITHUB_WORKSPACE/test_env/bin/activate pytest -m "$MARKER" -vv -ra --durations=0 --durations-min=0.001 | tee report.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of ${MARKER}" >> $GITHUB_STEP_SUMMARY python $GITHUB_WORKSPACE/.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml new file mode 100644 index 0000000000..1c426ba11c --- /dev/null +++ b/.github/workflows/all-tests.yml @@ -0,0 +1,129 @@ +name: Complete 
tests + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + +env: + KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} + KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + +concurrency: # Cancel previous workflows on the same pull request + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + run: + name: ${{ matrix.os }} Python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] # Lower and higher versions we support + os: [macos-13, windows-latest, ubuntu-latest] + steps: + - uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + # cache: 'pip' # caching pip dependencies + + - name: Get current hash (SHA) of the ephy_testing_data repo + id: repo_hash + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + shell: bash + - name: Cache datasets + id: cache-datasets + uses: actions/cache/restore@v4 + with: + path: ~/spikeinterface_datasets + key: ${{ runner.os }}-datasets-${{ steps.repo_hash.outputs.dataset_hash }} + restore-keys: ${{ runner.os }}-datasets + + - name: Install packages + run: | + git config --global user.email "CI@example.com" + git config --global user.name "CI Almighty" + pip install -e .[test,extractors,streaming_extractors,full] + pip install tabulate + shell: bash + + - name: Installad datalad + run: | + pip install datalad-installer + if [ ${{ runner.os }} = 'Linux' ]; then + datalad-installer --sudo ok git-annex --method datalad/packages + elif [ ${{ runner.os }} = 'macOS' ]; then + datalad-installer --sudo ok git-annex --method brew + elif [ ${{ runner.os }} = 'Windows' ]; then + datalad-installer --sudo ok git-annex --method datalad/git-annex:release + fi + pip install datalad + git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + shell: bash + + - name: Set execute permissions on run_tests.sh + run: chmod +x .github/run_tests.sh + shell: bash + + - name: Test core + run: pytest -m "core" + shell: bash + + - name: Test extractors + env: + HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell + run: pytest -m "extractors" + shell: bash + + - name: Test preprocessing + run: ./.github/run_tests.sh "preprocessing and not deepinterpolation" --no-virtual-env + shell: bash + + - name: Test postprocessing + run: ./.github/run_tests.sh postprocessing --no-virtual-env + shell: bash + + - name: Test quality metrics + run: ./.github/run_tests.sh qualitymetrics --no-virtual-env + shell: bash + + - name: Test comparison + run: ./.github/run_tests.sh comparison --no-virtual-env + shell: bash + + - name: Test core sorters + run: ./.github/run_tests.sh sorters --no-virtual-env + shell: bash + + - name: Test internal sorters + run: ./.github/run_tests.sh sorters_internal --no-virtual-env + shell: bash + + - name: Test curation + run: ./.github/run_tests.sh curation --no-virtual-env + shell: bash + + - name: Test widgets + run: ./.github/run_tests.sh widgets --no-virtual-env + shell: bash + + - name: Test exporters + run: ./.github/run_tests.sh exporters 
--no-virtual-env + shell: bash + + - name: Test sortingcomponents + run: ./.github/run_tests.sh sortingcomponents --no-virtual-env + shell: bash + + - name: Test generation + run: ./.github/run_tests.sh generation --no-virtual-env + shell: bash diff --git a/pyproject.toml b/pyproject.toml index 58c0f66e44..b26337ad01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,10 +137,9 @@ test = [ # for sortingview backend "sortingview", - - # recent datalad need a too recent version for git-annex - # so we use an old one here - "datalad==0.16.2", + # Download data + "pooch>=1.8.2", + "datalad>=1.0.2", ## install tridesclous for testing ## "tridesclous>=1.6.8", diff --git a/src/spikeinterface/core/datasets.py b/src/spikeinterface/core/datasets.py index 59cfbfac55..c8d897d9fc 100644 --- a/src/spikeinterface/core/datasets.py +++ b/src/spikeinterface/core/datasets.py @@ -14,10 +14,13 @@ def download_dataset( remote_path: str = "mearec/mearec_test_10s.h5", local_folder: Path | None = None, update_if_exists: bool = False, - unlock: bool = False, ) -> Path: """ - Function to download dataset from a remote repository using datalad. + Function to download dataset from a remote repository using a combination of datalad and pooch. + + Pooch is designed to download single files from a remote repository. + Because our datasets in gin sometimes point just to a folder, we still use datalad to download + a list of all the files in the folder and then use pooch to download them one by one. Parameters ---------- @@ -25,19 +28,25 @@ def download_dataset( The repository to download the dataset from remote_path : str, default: "mearec/mearec_test_10s.h5" A specific subdirectory in the repository to download (e.g. Mearec, SpikeGLX, etc) - local_folder : str, default: None + local_folder : str, optional The destination folder / directory to download the dataset to. - defaults to the path "get_global_dataset_folder()" / f{repo_name} (see `spikeinterface.core.globals`) + if None, then the path "get_global_dataset_folder()" / f{repo_name} is used (see `spikeinterface.core.globals`) update_if_exists : bool, default: False Forces re-download of the dataset if it already exists, default: False - unlock : bool, default: False - Use to enable the edition of the downloaded file content, default: False Returns ------- Path The local path to the downloaded dataset + + Notes + ----- + The reason we use pooch is because have had problems with datalad not being able to download + data on windows machines. Especially in the CI. 
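+
+    A minimal call, shown here only as a sketch (the defaults point to the NeuralEnsemble gin repository):
+
+        >>> local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5")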
+ + See https://handbook.datalad.org/en/latest/intro/windows.html """ + import pooch import datalad.api from datalad.support.gitrepo import GitRepo @@ -45,25 +54,40 @@ def download_dataset( base_local_folder = get_global_dataset_folder() base_local_folder.mkdir(exist_ok=True, parents=True) local_folder = base_local_folder / repo.split("/")[-1] + local_folder.mkdir(exist_ok=True, parents=True) + else: + if not local_folder.is_dir(): + local_folder.mkdir(exist_ok=True, parents=True) local_folder = Path(local_folder) if local_folder.exists() and GitRepo.is_valid_repo(local_folder): dataset = datalad.api.Dataset(path=local_folder) - # make sure git repo is in clean state - repo = dataset.repo - if update_if_exists: - repo.call_git(["checkout", "--force", "master"]) - dataset.update(merge=True) else: dataset = datalad.api.install(path=local_folder, source=repo) local_path = local_folder / remote_path + dataset_status = dataset.status(path=remote_path, annex="simple") + + # Download only files that also have a git-annex key + dataset_status_files = [status for status in dataset_status if status["type"] == "file"] + dataset_status_files = [status for status in dataset_status_files if "key" in status] - # This downloads the data set content - dataset.get(remote_path) + git_annex_hashing_algorithm = {"MD5E": "md5"} + for status in dataset_status_files: + hash_algorithm = git_annex_hashing_algorithm[status["backend"]] + hash = status["keyname"].split(".")[0] + known_hash = f"{hash_algorithm}:{hash}" + fname = Path(status["path"]).relative_to(local_folder) + url = f"{repo}/raw/master/{fname.as_posix()}" + expected_full_path = local_folder / fname - # Unlock files of a dataset in order to be able to edit the actual content - if unlock: - dataset.unlock(remote_path, recursive=True) + full_path = pooch.retrieve( + url=url, + fname=str(fname), + path=local_folder, + known_hash=known_hash, + progressbar=True, + ) + assert full_path == str(expected_full_path) return local_path diff --git a/src/spikeinterface/extractors/tests/common_tests.py b/src/spikeinterface/extractors/tests/common_tests.py index dcbd2304f1..5432efa9f3 100644 --- a/src/spikeinterface/extractors/tests/common_tests.py +++ b/src/spikeinterface/extractors/tests/common_tests.py @@ -18,8 +18,9 @@ class CommonTestSuite: downloads = [] entities = [] - def setUp(self): - for remote_path in self.downloads: + @classmethod + def setUpClass(cls): + for remote_path in cls.downloads: download_dataset(repo=gin_repo, remote_path=remote_path, local_folder=local_folder, update_if_exists=True) diff --git a/src/spikeinterface/extractors/tests/test_datalad_downloading.py b/src/spikeinterface/extractors/tests/test_datalad_downloading.py index 97e68146a6..8abccc6707 100644 --- a/src/spikeinterface/extractors/tests/test_datalad_downloading.py +++ b/src/spikeinterface/extractors/tests/test_datalad_downloading.py @@ -1,15 +1,12 @@ import pytest from spikeinterface.core import download_dataset +import importlib.util -try: - import datalad - HAVE_DATALAD = True -except: - HAVE_DATALAD = False - - -@pytest.mark.skipif(not HAVE_DATALAD, reason="No datalad") +@pytest.mark.skipif( + importlib.util.find_spec("pooch") is None or importlib.util.find_spec("datalad") is None, + reason="Either pooch or datalad is not installed", +) def test_download_dataset(): repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" remote_path = "mearec" diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py 
index 379bf00c6b..acd7ebe8ad 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -351,8 +351,10 @@ def test_pickling(self): pass -# We run plexon2 tests only if we have dependencies (wine) -@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +# TODO solve plexon bug +@pytest.mark.skipif( + not has_plexon2_dependencies() or platform.system() == "Windows", reason="There is a bug on windows" +) class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] @@ -361,6 +363,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(not has_plexon2_dependencies() or platform.system() == "Windows", reason="There is a bug") @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor @@ -370,7 +373,7 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ] -@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +@pytest.mark.skipif(not has_plexon2_dependencies() or platform.system() == "Windows", reason="There is a bug") class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 08ec32c6c2..38ae3b2c5e 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -136,7 +136,7 @@ def test_compute_for_all_spikes(self, sparse): ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) all_pc2 = np.load(pc_file2) - assert np.array_equal(all_pc1, all_pc2) + np.testing.assert_almost_equal(all_pc1, all_pc2, decimal=3) def test_project_new(self): """ diff --git a/src/spikeinterface/sorters/tests/test_container_tools.py b/src/spikeinterface/sorters/tests/test_container_tools.py index 3ae03abff1..0369bca860 100644 --- a/src/spikeinterface/sorters/tests/test_container_tools.py +++ b/src/spikeinterface/sorters/tests/test_container_tools.py @@ -8,6 +8,7 @@ from spikeinterface import generate_ground_truth_recording from spikeinterface.sorters.container_tools import find_recording_folders, ContainerClient, install_package_in_container +import platform ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -58,7 +59,9 @@ def test_find_recording_folders(setup_module): assert str(f2[0]) == str((cache_folder / "multi").absolute()) # in this case the paths are in 3 separate drives - assert len(f3) == 3 + # Not a good test on windows because all the paths resolve to C when absolute in `find_recording_folders` + if platform.system() != "Windows": + assert len(f3) == 3 @pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally") From 921ec82c6ab955a1622fb28e60b05dbb455529c6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 26 Jun 2024 10:43:49 +0200 Subject: [PATCH 072/103] Template similarity lags (#2941) Extend template similarity with lags and distance metrics --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino --- .../comparison/basecomparison.py | 8 
+- .../comparison/multicomparisons.py | 9 +- .../comparison/paircomparisons.py | 47 +++--- .../postprocessing/template_similarity.py | 143 ++++++++++++++++-- .../tests/test_template_similarity.py | 17 ++- 5 files changed, 179 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 0fdda745b2..f1d2130d38 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -313,9 +313,11 @@ class MixinTemplateComparison: """ Mixin for template comparisons to define: * similarity method - * sparsity + * support + * num_shifts """ - def __init__(self, similarity_method="cosine_similarity", sparsity_dict=None): + def __init__(self, similarity_method="cosine", support="union", num_shifts=0): self.similarity_method = similarity_method - self.sparsity_dict = sparsity_dict + self.support = support + self.num_shifts = num_shifts diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 7cde985b37..35c298d4ac 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -333,8 +333,9 @@ def __init__( match_score=0.8, chance_score=0.3, verbose=False, - similarity_method="cosine_similarity", - sparsity_dict=None, + similarity_method="cosine", + support="union", + num_shifts=0, do_matching=True, ): if name_list is None: @@ -347,7 +348,9 @@ def __init__( chance_score=chance_score, verbose=verbose, ) - MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict) + MixinTemplateComparison.__init__( + self, similarity_method=similarity_method, support=support, num_shifts=num_shifts + ) if do_matching: self._compute_all() diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index ea4b72b200..7d5f04dfdd 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -697,24 +697,26 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): Parameters ---------- sorting_analyzer_1 : SortingAnalyzer - The first SortingAnalyzer to get templates to compare + The first SortingAnalyzer to get templates to compare. sorting_analyzer_2 : SortingAnalyzer - The second SortingAnalyzer to get templates to compare + The second SortingAnalyzer to get templates to compare. unit_ids1 : list, default: None - List of units from sorting_analyzer_1 to compare + List of units from sorting_analyzer_1 to compare. unit_ids2 : list, default: None - List of units from sorting_analyzer_2 to compare - similarity_method : str, default: "cosine_similarity" - Method for the similaroty matrix - sparsity_dict : dict, default: None - Dictionary for sparsity + List of units from sorting_analyzer_2 to compare. + similarity_method : "cosine" | "l1" | "l2", default: "cosine" + Method for the similarity matrix. + support : "dense" | "union" | "intersection", default: "union" + The support to compute the similarity matrix. + num_shifts : int, default: 0 + Number of shifts to use to shift templates to maximize similarity. verbose : bool, default: False - If True, output is verbose + If True, output is verbose. Returns ------- comparison : TemplateComparison - The output TemplateComparison object + The output TemplateComparison object. 
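+
+    Examples
+    --------
+    A rough usage sketch (both analyzers must be computed on the same channels):
+
+    >>> comp = TemplateComparison(sorting_analyzer_1, sorting_analyzer_2, similarity_method="l2", num_shifts=2)
+    >>> agreement = comp.agreement_scores  # unit-by-unit agreement built from template similarity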
""" def __init__( @@ -727,8 +729,9 @@ def __init__( unit_ids2=None, match_score=0.7, chance_score=0.3, - similarity_method="cosine_similarity", - sparsity_dict=None, + similarity_method="cosine", + support="union", + num_shifts=0, verbose=False, ): if name1 is None: @@ -745,7 +748,9 @@ def __init__( chance_score=chance_score, verbose=verbose, ) - MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict) + MixinTemplateComparison.__init__( + self, similarity_method=similarity_method, support=support, num_shifts=num_shifts + ) self.sorting_analyzer_1 = sorting_analyzer_1 self.sorting_analyzer_2 = sorting_analyzer_2 @@ -754,10 +759,9 @@ def __init__( # two options: all channels are shared or partial channels are shared if sorting_analyzer_1.recording.get_num_channels() != sorting_analyzer_2.recording.get_num_channels(): - raise NotImplementedError + raise ValueError("The two recordings must have the same number of channels") if np.any([ch1 != ch2 for (ch1, ch2) in zip(channel_ids1, channel_ids2)]): - # TODO: here we can check location and run it on the union. Might be useful for reconfigurable probes - raise NotImplementedError + raise ValueError("The two recordings must have the same channel ids") self.matches = dict() @@ -768,11 +772,6 @@ def __init__( unit_ids2 = sorting_analyzer_2.sorting.get_unit_ids() self.unit_ids = [unit_ids1, unit_ids2] - if sparsity_dict is not None: - raise NotImplementedError - else: - self.sparsity = None - self._do_agreement() self._do_matching() @@ -781,7 +780,11 @@ def _do_agreement(self): print("Agreement scores...") agreement_scores = compute_template_similarity_by_pair( - self.sorting_analyzer_1, self.sorting_analyzer_2, method=self.similarity_method + self.sorting_analyzer_1, + self.sorting_analyzer_2, + method=self.similarity_method, + support=self.support, + num_shifts=self.num_shifts, ) import pandas as pd diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 15a1fe34ce..777f84dfd7 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from ..core.template_tools import get_dense_templates_array @@ -9,13 +10,26 @@ class ComputeTemplateSimilarity(AnalyzerExtension): """Compute similarity between templates with several methods. + Similarity is defined as 1 - distance(T_1, T_2) for two templates T_1, T_2 + Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - method: str, default: "cosine_similarity" - The method to compute the similarity + method : str, default: "cosine" + The method to compute the similarity. Can be in ["cosine", "l2", "l1"] + max_lag_ms : float, default: 0 + If specified, the best distance for all given lag within max_lag_ms is kept, for every template + support : "dense" | "union" | "intersection", default: "union" + Support that should be considered to compute the distances between the templates, given their sparsities. 
+ Can be either ["dense", "union", "intersection"] + + In case of "l1" or "l2", the formula used is: + similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)) + + In case of cosine this is: + similarity = 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)) Returns ------- @@ -32,8 +46,15 @@ class ComputeTemplateSimilarity(AnalyzerExtension): def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, method="cosine_similarity"): - params = dict(method=method) + def _set_params(self, method="cosine", max_lag_ms=0, support="union"): + if method == "cosine_similarity": + warnings.warn( + "The method 'cosine_similarity' is deprecated and will be removed in the next version. Use 'cosine' instead.", + DeprecationWarning, + stacklevel=2, + ) + method = "cosine" + params = dict(method=method, max_lag_ms=max_lag_ms, support=support) return params def _select_extension_data(self, unit_ids): @@ -43,11 +64,19 @@ def _select_extension_data(self, unit_ids): return dict(similarity=new_similarity) def _run(self, verbose=False): + num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000) templates_array = get_dense_templates_array( self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled ) + sparsity = self.sorting_analyzer.sparsity similarity = compute_similarity_with_templates_array( - templates_array, templates_array, method=self.params["method"] + templates_array, + templates_array, + method=self.params["method"], + num_shifts=num_shifts, + support=self.params["support"], + sparsity=sparsity, + other_sparsity=sparsity, ) self.data["similarity"] = similarity @@ -60,25 +89,109 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def compute_similarity_with_templates_array(templates_array, other_templates_array, method): +def compute_similarity_with_templates_array( + templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None +): + import sklearn.metrics.pairwise if method == "cosine_similarity": - assert templates_array.shape[0] == other_templates_array.shape[0] - templates_flat = templates_array.reshape(templates_array.shape[0], -1) - other_templates_flat = templates_array.reshape(other_templates_array.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, other_templates_flat) - + method = "cosine" + + all_metrics = ["cosine", "l1", "l2"] + + if method not in all_metrics: + raise ValueError(f"compute_template_similarity (method {method}) not exists") + + assert ( + templates_array.shape[1] == other_templates_array.shape[1] + ), "The number of samples in the templates should be the same for both arrays" + assert ( + templates_array.shape[2] == other_templates_array.shape[2] + ), "The number of channels in the templates should be the same for both arrays" + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + num_channels = templates_array.shape[2] + other_num_templates = other_templates_array.shape[0] + + mask = None + if sparsity is not None and other_sparsity is not None: + if support == "intersection": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + elif support == "union": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + units_overlaps = np.sum(mask, axis=2) > 0 + mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, 
:, :]) + mask[~units_overlaps] = False + if mask is not None: + units_overlaps = np.sum(mask, axis=2) > 0 + overlapping_templates = {} + for i in range(num_templates): + overlapping_templates[i] = np.flatnonzero(units_overlaps[i]) else: - raise ValueError(f"compute_template_similarity(method {method}) not exists") + # here we make a dense mask and overlapping templates + overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)} + mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + + assert num_shifts < num_samples, "max_lag is too large" + num_shifts_both_sides = 2 * num_shifts + 1 + distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + for count, shift in enumerate(range(-num_shifts, 1)): + src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in range(num_templates): + src_template = src_sliced_templates[i] + tgt_templates = tgt_sliced_templates[overlapping_templates[i]] + for gcount, j in enumerate(overlapping_templates[i]): + # symmetric values are handled later + if num_templates == other_num_templates and j < i: + continue + src = src_template[:, mask[i, j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + + if method == "l1": + norm_i = np.sum(np.abs(src)) + norm_j = np.sum(np.abs(tgt)) + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1") + distances[count, i, j] /= norm_i + norm_j + elif method == "l2": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2") + distances[count, i, j] /= norm_i + norm_j + else: + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine") + if num_templates == other_num_templates: + distances[count, j, i] = distances[count, i, j] + + if num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T + + distances = np.min(distances, axis=0) + similarity = 1 - distances return similarity -def compute_template_similarity_by_pair(sorting_analyzer_1, sorting_analyzer_2, method="cosine_similarity"): +def compute_template_similarity_by_pair( + sorting_analyzer_1, sorting_analyzer_2, method="cosine", support="union", num_shifts=0 +): templates_array_1 = get_dense_templates_array(sorting_analyzer_1, return_scaled=True) templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_scaled=True) - similarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method) + sparsity_1 = sorting_analyzer_1.sparsity + sparsity_2 = sorting_analyzer_2.sparsity + similarity = compute_similarity_with_templates_array( + templates_array_1, + templates_array_2, + method=method, + support=support, + num_shifts=num_shifts, + sparsity=sparsity_1, + other_sparsity=sparsity_2, + ) return similarity diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index a4de2a3a90..f98a5624db 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ 
b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,3 +1,5 @@ +import pytest + from spikeinterface.postprocessing.tests.common_extension_tests import ( AnalyzerExtensionCommonTestSuite, ) @@ -7,8 +9,19 @@ class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): - def test_extension(self): - self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine_similarity")) + @pytest.mark.parametrize( + "params", + [ + dict(method="cosine"), + dict(method="l2"), + dict(method="l1", max_lag_ms=0.2), + dict(method="l1", support="intersection"), + dict(method="l2", support="union"), + dict(method="cosine", support="dense"), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeTemplateSimilarity, params=params) def test_check_equal_template_with_distribution_overlap(self): """ From 2867d7c09977cd5c77a9bf5fdf0793e3e1f05314 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 06:32:08 -0600 Subject: [PATCH 073/103] remove cached dependencies (#3080) --- .../actions/show-test-environment/action.yml | 23 ------------------- .github/workflows/full-test-with-codecov.yml | 11 --------- .github/workflows/full-test.yml | 8 ------- 3 files changed, 42 deletions(-) delete mode 100644 .github/actions/show-test-environment/action.yml diff --git a/.github/actions/show-test-environment/action.yml b/.github/actions/show-test-environment/action.yml deleted file mode 100644 index 3bc062d414..0000000000 --- a/.github/actions/show-test-environment/action.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Log test environment -description: Shows installed packages by pip, git-annex and cached testing files - -inputs: {} - -runs: - using: "composite" - steps: - - name: git-annex version - run: | - git-annex version - shell: bash - - name: Packages installed - run: | - source ${{ github.workspace }}/test_env/bin/activate - pip list - shell: bash - - name: Check ephy_testing_data files - run: | - if [ -d "$HOME/spikeinterface_datasets" ]; then - find $HOME/spikeinterface_datasets - fi - shell: bash diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 75847759f6..ab4a083ae1 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -23,15 +23,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: '3.10' - - name: Get current year-month - id: date - run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - name: Restore cached virtual environment with dependencies - uses: actions/cache/restore@v4 - id: cache-venv - with: - path: ${{ github.workspace }}/test_env - key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -49,8 +40,6 @@ jobs: restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - - name: Shows installed packages by pip, git-annex and cached testing files - uses: ./.github/actions/show-test-environment - name: run tests env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index b432fbd4d5..ed2f28dc23 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -31,12 +31,6 @@ jobs: - name: Get current year-month id: date run: echo 
"date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - name: Restore cached virtual environment with dependencies - uses: actions/cache/restore@v4 - id: cache-venv - with: - path: ${{ github.workspace }}/test_env - key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -54,8 +48,6 @@ jobs: restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - - name: Shows installed packages by pip, git-annex and cached testing files - uses: ./.github/actions/show-test-environment - name: Get changed files id: changed-files uses: tj-actions/changed-files@v41 From b88ddcb9969e01c019452e1a0d1832b092390ea8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:17:09 +0100 Subject: [PATCH 074/103] Respond to review --- src/spikeinterface/extractors/neoextractors/intan.py | 4 ++-- src/spikeinterface/extractors/toy_example.py | 6 +++--- src/spikeinterface/preprocessing/filter.py | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 9d4db3103c..50fda79123 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -27,8 +27,8 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): If True, data that violates integrity assumptions will be loaded. At the moment the only integrity check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. - use_names_as_ids : bool or None, default: None - If True, use channel names as IDs. If None, use default IDs. + use_names_as_ids : bool, default: False + If True, use channel names as IDs. If False, use default IDs inherited from neo. """ mode = "file" diff --git a/src/spikeinterface/extractors/toy_example.py b/src/spikeinterface/extractors/toy_example.py index 2f007cca88..55b787f3ed 100644 --- a/src/spikeinterface/extractors/toy_example.py +++ b/src/spikeinterface/extractors/toy_example.py @@ -62,13 +62,13 @@ def toy_example( seed : int or None, default: None Seed for random initialization. upsample_factor : None or int, default: None - A upsampling factor used only when templates are not provided. + An upsampling factor, used only when templates are not provided. num_columns : int, default: 1 Number of columns in probe. average_peak_amplitude : float, default: -100 - Average peak amplitude of generated templates + Average peak amplitude of generated templates. contact_spacing_um : float, default: 40.0 - Spacing between probe contacts. + Spacing between probe contacts in micrometers. Returns ------- diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index d18227ca83..93462ac5d8 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -11,7 +11,9 @@ _common_filter_docs = """**filter_kwargs : dict Certain keyword arguments for `scipy.signal` filters: filter_order : order - The order of the filter + The order of the filter. Note as filtering is applied with scipy's + `filtfilt` functions (i.e. 
acausal, zero-phase) the effective + order will be double the `filter_order`. filter_mode : "sos" | "ba", default: "sos" Filter form of the filter coefficients: - second-order sections ("sos") From a166e5a3d419c49aa6afc69f0e2f98ea7eb9d0c3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 15:33:51 -0600 Subject: [PATCH 075/103] add recording iterator --- src/spikeinterface/core/core_tools.py | 58 +++++++++++++++++-- src/spikeinterface/sorters/container_tools.py | 11 +--- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index f3d8b3df7f..3fe4939524 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from pathlib import Path, WindowsPath -from typing import Union +from typing import Union, Generator import os import sys import datetime @@ -8,6 +8,7 @@ from copy import deepcopy import importlib from math import prod +from collections import namedtuple import numpy as np @@ -183,6 +184,50 @@ def is_dict_extractor(d: dict) -> bool: return is_extractor +recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) + + +def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: + """ + Iterator for recursive traversal of a dictionary. + This function explores the dictionary recursively and yields the path to each value along with the value itself. + + By path here we mean the keys that lead to the value in the dictionary: + e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b'). + + See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure. + + Parameters + ---------- + extractor_dict : dict + Input dictionary + + Yields + ------ + recording_dict_element + Named tuple containing the value, the name, and the access_path to the value in the dictionary. + + """ + + def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): + if isinstance(dict_list_or_value, dict): + for k, v in dict_list_or_value.items(): + yield from _recording_dict_iterator(v, access_path + (k,), name=k) + elif isinstance(dict_list_or_value, list): + for i, v in enumerate(dict_list_or_value): + yield from _recording_dict_iterator( + v, access_path + (i,), name=name + ) # Propagate name of list to children + else: + yield recording_dict_element( + value=dict_list_or_value, + name=name, + access_path=access_path, + ) + + yield from _recording_dict_iterator(extractor_dict) + + def recursive_path_modifier(d, func, target="path", copy=True) -> dict: """ Generic function for recursive modification of paths in an extractor dict. 
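As a rough usage sketch of the iterator added above (`recording` stands for any extractor object; the
function is renamed to `extractor_dict_iterator` later in this series):

    from spikeinterface.core.core_tools import recording_dict_iterator

    extractor_dict = recording.to_dict()
    for element in recording_dict_iterator(extractor_dict):
        # each element is a namedtuple with fields value, name and access_path
        if "path" in element.name:
            print(element.access_path, "->", element.value)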
@@ -250,15 +295,16 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d): +def _get_paths_list(d: dict) -> list[str | Path]: # this explore a dict and get all paths flatten in a list # the trick is to use a closure func called by recursive_path_modifier() - path_list = [] - def append_to_path(p): - path_list.append(p) + element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)] + + # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks + # path_list = [p for p in path_list if Path(p).exists()] - recursive_path_modifier(d, append_to_path, target="path", copy=True) return path_list diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 60eb080ae5..8e03090eaf 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -9,19 +9,14 @@ # TODO move this inside functions -from spikeinterface.core.core_tools import recursive_path_modifier +from spikeinterface.core.core_tools import recursive_path_modifier, _get_paths_list def find_recording_folders(d): """Finds all recording folders 'paths' in a dict""" - folders_to_mount = [] - def append_parent_folder(p): - p = Path(p) - folders_to_mount.append(p.resolve().absolute().parent) - return p - - _ = recursive_path_modifier(d, append_parent_folder, target="path", copy=True) + path_list = _get_paths_list(d=d) + folders_to_mount = [Path(p).resolve().parent for p in path_list] try: # this will fail if on different drives (Windows) base_folders_to_mount = [Path(os.path.commonpath(folders_to_mount))] From 27a7c9a96c2e8f008109c99d8dd90ac52ac5fd3e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 16:58:39 -0600 Subject: [PATCH 076/103] add and fix tests --- src/spikeinterface/core/core_tools.py | 83 ++++++++-- .../core/tests/test_core_tools.py | 153 ++++++++++++------ 2 files changed, 170 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 3fe4939524..9e90b56c8d 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -187,7 +187,7 @@ def is_dict_extractor(d: dict) -> bool: recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) -def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. 
@@ -209,13 +209,13 @@ def recording_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el """ - def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): + def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): if isinstance(dict_list_or_value, dict): for k, v in dict_list_or_value.items(): - yield from _recording_dict_iterator(v, access_path + (k,), name=k) + yield from _extractor_dict_iterator(v, access_path + (k,), name=k) elif isinstance(dict_list_or_value, list): for i, v in enumerate(dict_list_or_value): - yield from _recording_dict_iterator( + yield from _extractor_dict_iterator( v, access_path + (i,), name=name ) # Propagate name of list to children else: @@ -225,7 +225,32 @@ def _recording_dict_iterator(dict_list_or_value, access_path=(), name=""): access_path=access_path, ) - yield from _recording_dict_iterator(extractor_dict) + yield from _extractor_dict_iterator(extractor_dict) + + +def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value): + """ + In place modification of a value in a nested dictionary given its access path. + + Parameters + ---------- + extractor_dict : dict + The dictionary to modify + access_path : tuple + The path to the value in the dictionary + new_value : object + The new value to set + + Returns + ------- + dict + The modified dictionary + """ + + current = extractor_dict + for key in access_path[:-1]: + current = current[key] + current[access_path[-1]] = new_value def recursive_path_modifier(d, func, target="path", copy=True) -> dict: @@ -295,12 +320,13 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d: dict) -> list[str | Path]: - # this explore a dict and get all paths flatten in a list - # the trick is to use a closure func called by recursive_path_modifier() +# This is the current definition that an element in a recording_dict is a path +# This is shared across a couple of definition so it is here for DNRY +element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + - element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) - path_list = [e.value for e in recording_dict_iterator(d) if element_is_path(e)] +def _get_paths_list(d: dict) -> list[str | Path]: + path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)] # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks # path_list = [p for p in path_list if Path(p).exists()] @@ -364,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool: return len(not_possible) == 0 -def make_paths_relative(input_dict, relative_folder) -> dict: +def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: """ Recursively transform a dict describing an BaseExtractor to make every path relative to a folder. @@ -380,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict: output_dict: dict A copy of the input dict with modified paths. 
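+
+    For example, as a sketch (with this implementation only paths that exist on disk are converted):
+
+        >>> d = {"kwargs": {"path": "/data/session/sub/rec.bin"}}
+        >>> make_paths_relative(d, "/data/session")["kwargs"]["path"]  # -> "sub/rec.bin" if the file exists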
""" + relative_folder = Path(relative_folder).resolve().absolute() - func = lambda p: _relative_to(p, relative_folder) - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + # Only paths that exist are made relative + path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()] + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + new_value = _relative_to(element.value, relative_folder) + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict @@ -405,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder): base_folder = Path(base_folder) # use as_posix instead of str to make the path unix like even on window func = lambda p: (base_folder / p).resolve().absolute().as_posix() - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + output_dict = deepcopy(input_dict) + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + absolute_path = (base_folder / element.value).resolve() + if Path(absolute_path).exists(): + new_value = absolute_path.as_posix() # Not so sure about this, Sam + set_value_in_recording_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict def recursive_key_finder(d, key): # Find all values for a key on a dictionary, even if nested + # TODO refactor to use extractor_dict_iterator + for k, v in d.items(): if isinstance(v, dict): yield from recursive_key_finder(v, key) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 8e00dcb779..043e0cabf3 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -51,14 +51,9 @@ def test_path_utils_functions(): assert d2["kwargs"]["path"].startswith("/yop") assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") - d3 = make_paths_relative(d, Path("/yep")) - assert d3["kwargs"]["path"] == "sub/path1" - assert d3["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - - d4 = make_paths_absolute(d3, "/yop") - assert d4["kwargs"]["path"].startswith("/yop") - assert d4["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_relative_path_on_windows(): if platform.system() == "Windows": # test for windows Path d = { @@ -74,57 +69,111 @@ def test_path_utils_functions(): } } - d2 = make_paths_relative(d, "c:\\yep") - # the str be must unix like path even on windows for more portability - assert d2["kwargs"]["path"] == "sub/path1" - assert d2["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - # same drive assert check_paths_relative(d, r"c:\yep") # not the same drive assert not check_paths_relative(d, r"d:\yep") - d = { - "kwargs": { - "path": r"\\host\share\yep\sub\path1", - } - } - # UNC cannot be relative to d: drive - assert not check_paths_relative(d, r"d:\yep") - # UNC can be relative to the same UNC - assert check_paths_relative(d, r"\\host\share") - - def test_convert_string_to_bytes(): - # Test SI prefixes - assert convert_string_to_bytes("1k") == 1000 - assert convert_string_to_bytes("1M") == 1000000 - assert 
convert_string_to_bytes("1G") == 1000000000 - assert convert_string_to_bytes("1T") == 1000000000000 - assert convert_string_to_bytes("1P") == 1000000000000000 - # Test IEC prefixes - assert convert_string_to_bytes("1Ki") == 1024 - assert convert_string_to_bytes("1Mi") == 1048576 - assert convert_string_to_bytes("1Gi") == 1073741824 - assert convert_string_to_bytes("1Ti") == 1099511627776 - assert convert_string_to_bytes("1Pi") == 1125899906842624 - # Test mixed values - assert convert_string_to_bytes("1.5k") == 1500 - assert convert_string_to_bytes("2.5M") == 2500000 - assert convert_string_to_bytes("0.5G") == 500000000 - assert convert_string_to_bytes("1.2T") == 1200000000000 - assert convert_string_to_bytes("1.5Pi") == 1688849860263936 - # Test zero values - assert convert_string_to_bytes("0k") == 0 - assert convert_string_to_bytes("0Ki") == 0 - # Test invalid inputs (should raise assertion error) - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Z") - assert str(e.value) == "Unknown suffix: Z" - - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Xi") - assert str(e.value) == "Unknown suffix: Xi" +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_universal_naming_convention(): + d = { + "kwargs": { + "path": r"\\host\share\yep\sub\path1", + } + } + # UNC cannot be relative to d: drive + assert not check_paths_relative(d, r"d:\yep") + + # UNC can be relative to the same UNC + assert check_paths_relative(d, r"\\host\share") + + +def test_make_paths_relative(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + # Create the objects in the path + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + extractor_dict = { + "kwargs": { + "path": str(path_1), # Note this is different in windows and posix + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": str(path_2)}, + }, + } + } + modified_extractor_dict = make_paths_relative(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"] == "sub/path1" + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_make_paths_absolute(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + + extractor_dict = { + "kwargs": { + "path": "sub/path1", + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "sub/path2"}, + }, + } + } + + modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_convert_string_to_bytes(): + # Test SI prefixes + assert convert_string_to_bytes("1k") == 1000 + assert convert_string_to_bytes("1M") == 1000000 + assert 
convert_string_to_bytes("1G") == 1000000000 + assert convert_string_to_bytes("1T") == 1000000000000 + assert convert_string_to_bytes("1P") == 1000000000000000 + # Test IEC prefixes + assert convert_string_to_bytes("1Ki") == 1024 + assert convert_string_to_bytes("1Mi") == 1048576 + assert convert_string_to_bytes("1Gi") == 1073741824 + assert convert_string_to_bytes("1Ti") == 1099511627776 + assert convert_string_to_bytes("1Pi") == 1125899906842624 + # Test mixed values + assert convert_string_to_bytes("1.5k") == 1500 + assert convert_string_to_bytes("2.5M") == 2500000 + assert convert_string_to_bytes("0.5G") == 500000000 + assert convert_string_to_bytes("1.2T") == 1200000000000 + assert convert_string_to_bytes("1.5Pi") == 1688849860263936 + # Test zero values + assert convert_string_to_bytes("0k") == 0 + assert convert_string_to_bytes("0Ki") == 0 + # Test invalid inputs (should raise assertion error) + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Z") + assert str(e.value) == "Unknown suffix: Z" + + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Xi") + assert str(e.value) == "Unknown suffix: Xi" def test_normal_pdf() -> None: From b3b85b2fe5670217d80c4adec1a751d1e1d5d024 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 26 Jun 2024 17:21:45 -0600 Subject: [PATCH 077/103] naming --- src/spikeinterface/core/core_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 9e90b56c8d..d5480d6f00 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -228,7 +228,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): yield from _extractor_dict_iterator(extractor_dict) -def set_value_in_recording_dict(extractor_dict: dict, access_path: tuple, new_value): +def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value): """ In place modification of a value in a nested dictionary given its access path. 
@@ -416,7 +416,7 @@ def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: output_dict = deepcopy(input_dict) for element in path_elements_in_dict: new_value = _relative_to(element.value, relative_folder) - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, @@ -453,7 +453,7 @@ def make_paths_absolute(input_dict, base_folder): absolute_path = (base_folder / element.value).resolve() if Path(absolute_path).exists(): new_value = absolute_path.as_posix() # Not so sure about this, Sam - set_value_in_recording_dict( + set_value_in_extractor_dict( extractor_dict=output_dict, access_path=element.access_path, new_value=new_value, From d794c8220e9e2ed2431636e53aee9b7b8d6b998b Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Thu, 27 Jun 2024 00:39:58 -0600 Subject: [PATCH 078/103] windows test remove inner conditional --- .../core/tests/test_core_tools.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 043e0cabf3..ed13bd46fd 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,25 +54,24 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - if platform.system() == "Windows": - # test for windows Path - d = { - "kwargs": { - "path": r"c:\yep\sub\path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": r"c:\yep\sub\path2"}, - }, - } + + d = { + "kwargs": { + "path": r"c:\yep\sub\path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": r"c:\yep\sub\path2"}, + }, } + } - # same drive - assert check_paths_relative(d, r"c:\yep") - # not the same drive - assert not check_paths_relative(d, r"d:\yep") + # same drive + assert check_paths_relative(d, r"c:\yep") + # not the same drive + assert not check_paths_relative(d, r"d:\yep") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") @@ -139,8 +138,8 @@ def test_make_paths_absolute(tmp_path): } modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) - assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path)) - assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path)) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path.as_posix())) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path.as_posix())) assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" From c1e4eee519c289899f2650d98e6210d631ae42f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:41:00 +0000 Subject: [PATCH 079/103] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_core_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index ed13bd46fd..724517577c 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py 
+++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -54,7 +54,7 @@ def test_path_utils_functions(): @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") def test_relative_path_on_windows(): - + d = { "kwargs": { "path": r"c:\yep\sub\path1", From 0d993421fc2f4bb6e35facf25164b3a370d28c03 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 00:58:09 -0600 Subject: [PATCH 080/103] Add machinery to run test only on changed files (#3084) Improve full tests and file changed machinery Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino --- .github/determine_testing_environment.py | 118 ++++++++++++++ .github/workflows/all-tests.yml | 150 ++++++++++++++---- pyproject.toml | 5 + .../tests/test_templatecomparison.py | 7 +- 4 files changed, 245 insertions(+), 35 deletions(-) create mode 100644 .github/determine_testing_environment.py diff --git a/.github/determine_testing_environment.py b/.github/determine_testing_environment.py new file mode 100644 index 0000000000..4945ccc807 --- /dev/null +++ b/.github/determine_testing_environment.py @@ -0,0 +1,118 @@ +from pathlib import Path +import argparse +import os + + +# We get the list of files change as an input +parser = argparse.ArgumentParser() +parser.add_argument("changed_files_in_the_pull_request", nargs="*", help="List of changed files") +args = parser.parse_args() + +changed_files_in_the_pull_request = args.changed_files_in_the_pull_request +changed_files_in_the_pull_request_paths = [Path(file) for file in changed_files_in_the_pull_request] + +# We assume nothing has been changed + +core_changed = False +pyproject_toml_changed = False +neobaseextractor_changed = False +extractors_changed = False +plexon2_changed = False +preprocessing_changed = False +postprocessing_changed = False +qualitymetrics_changed = False +sorters_changed = False +sorters_external_changed = False +sorters_internal_changed = False +comparison_changed = False +curation_changed = False +widgets_changed = False +exporters_changed = False +sortingcomponents_changed = False +generation_changed = False + + +for changed_file in changed_files_in_the_pull_request_paths: + + file_is_in_src = changed_file.parts[0] == "src" + + if not file_is_in_src: + + if changed_file.name == "pyproject.toml": + pyproject_toml_changed = True + + else: + if changed_file.name == "neobaseextractor.py": + neobaseextractor_changed = True + elif changed_file.name == "plexon2.py": + extractors_changed = True + elif "core" in changed_file.parts: + conditions_changed = True + elif "extractors" in changed_file.parts: + extractors_changed = True + elif "preprocessing" in changed_file.parts: + preprocessing_changed = True + elif "postprocessing" in changed_file.parts: + postprocessing_changed = True + elif "qualitymetrics" in changed_file.parts: + qualitymetrics_changed = True + elif "comparison" in changed_file.parts: + comparison_changed = True + elif "curation" in changed_file.parts: + curation_changed = True + elif "widgets" in changed_file.parts: + widgets_changed = True + elif "exporters" in changed_file.parts: + exporters_changed = True + elif "sortingcomponents" in changed_file.parts: + sortingcomponents_changed = True + elif "generation" in changed_file.parts: + generation_changed = True + elif "sorters" in changed_file.parts: + if 
"external" in changed_file.parts: + sorters_external_changed = True + elif "internal" in changed_file.parts: + sorters_internal_changed = True + else: + sorters_changed = True + + +run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed +run_generation_tests = run_everything or generation_changed +run_extractor_tests = run_everything or extractors_changed +run_preprocessing_tests = run_everything or preprocessing_changed +run_postprocessing_tests = run_everything or postprocessing_changed +run_qualitymetrics_tests = run_everything or qualitymetrics_changed +run_curation_tests = run_everything or curation_changed +run_sortingcomponents_tests = run_everything or sortingcomponents_changed + +run_comparison_test = run_everything or run_generation_tests or comparison_changed +run_widgets_test = run_everything or run_qualitymetrics_tests or run_preprocessing_tests or widgets_changed +run_exporters_test = run_everything or run_widgets_test or exporters_changed + +run_sorters_test = run_everything or sorters_changed +run_internal_sorters_test = run_everything or run_sortingcomponents_tests or sorters_internal_changed + +install_plexon_dependencies = plexon2_changed + +environment_varaiables_to_add = { + "RUN_EXTRACTORS_TESTS": run_extractor_tests, + "RUN_PREPROCESSING_TESTS": run_preprocessing_tests, + "RUN_POSTPROCESSING_TESTS": run_postprocessing_tests, + "RUN_QUALITYMETRICS_TESTS": run_qualitymetrics_tests, + "RUN_CURATION_TESTS": run_curation_tests, + "RUN_SORTINGCOMPONENTS_TESTS": run_sortingcomponents_tests, + "RUN_GENERATION_TESTS": run_generation_tests, + "RUN_COMPARISON_TESTS": run_comparison_test, + "RUN_WIDGETS_TESTS": run_widgets_test, + "RUN_EXPORTERS_TESTS": run_exporters_test, + "RUN_SORTERS_TESTS": run_sorters_test, + "RUN_INTERNAL_SORTERS_TESTS": run_internal_sorters_test, + "INSTALL_PLEXON_DEPENDENCIES": install_plexon_dependencies, +} + +# Write the conditions to the GITHUB_ENV file +env_file = os.getenv("GITHUB_ENV") +with open(env_file, "a") as f: + for key, value in environment_varaiables_to_add.items(): + f.write(f"{key}={value}\n") diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 1c426ba11c..cce73a9008 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -32,14 +32,64 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - # cache: 'pip' # caching pip dependencies - - name: Get current hash (SHA) of the ephy_testing_data repo - id: repo_hash + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v41 + + - name: List all changed files + shell: bash + env: + ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }} + run: | + for file in ${ALL_CHANGED_FILES}; do + echo "$file was changed" + done + + - name: Set testing environment # This decides which tests are run and whether to install especial dependencies + shell: bash + run: | + changed_files="${{ steps.changed-files.outputs.all_changed_files }}" + python .github/determine_testing_environment.py $changed_files + + - name: Display testing environment + shell: bash + run: | + echo "RUN_EXTRACTORS_TESTS=${RUN_EXTRACTORS_TESTS}" + echo "RUN_PREPROCESSING_TESTS=${RUN_PREPROCESSING_TESTS}" + echo "RUN_POSTPROCESSING_TESTS=${RUN_POSTPROCESSING_TESTS}" + echo "RUN_QUALITYMETRICS_TESTS=${RUN_QUALITYMETRICS_TESTS}" + echo "RUN_CURATION_TESTS=${RUN_CURATION_TESTS}" + echo "RUN_SORTINGCOMPONENTS_TESTS=${RUN_SORTINGCOMPONENTS_TESTS}" + echo 
"RUN_GENERATION_TESTS=${RUN_GENERATION_TESTS}" + echo "RUN_COMPARISON_TESTS=${RUN_COMPARISON_TESTS}" + echo "RUN_WIDGETS_TESTS=${RUN_WIDGETS_TESTS}" + echo "RUN_EXPORTERS_TESTS=${RUN_EXPORTERS_TESTS}" + echo "RUN_SORTERS_TESTS=${RUN_SORTERS_TESTS}" + echo "RUN_INTERNAL_SORTERS_TESTS=${RUN_INTERNAL_SORTERS_TESTS}" + echo "INSTALL_PLEXON_DEPENDENCIES=${INSTALL_PLEXON_DEPENDENCIES}" + + - name: Install packages run: | - echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" - echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + pip install -e .[test_core] shell: bash + + - name: Test core + run: pytest -m "core" + shell: bash + + - name: Install Other Testing Dependencies + run: | + pip install -e .[test] + pip install tabulate + pip install pandas + shell: bash + + - name: Get current hash (SHA) of the ephy_testing_data repo + shell: bash + id: repo_hash + run: echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + - name: Cache datasets id: cache-datasets uses: actions/cache/restore@v4 @@ -48,82 +98,114 @@ jobs: key: ${{ runner.os }}-datasets-${{ steps.repo_hash.outputs.dataset_hash }} restore-keys: ${{ runner.os }}-datasets - - name: Install packages - run: | - git config --global user.email "CI@example.com" - git config --global user.name "CI Almighty" - pip install -e .[test,extractors,streaming_extractors,full] - pip install tabulate + - name: Install git-annex shell: bash - - - name: Installad datalad + if: env.RUN_EXTRACTORS_TESTS == 'true' run: | pip install datalad-installer if [ ${{ runner.os }} = 'Linux' ]; then - datalad-installer --sudo ok git-annex --method datalad/packages + wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz + mkdir /home/runner/work/installation + mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ + workdir=$(pwd) + cd /home/runner/work/installation + tar xvzf git-annex-standalone-amd64.tar.gz + echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH + cd $workdir elif [ ${{ runner.os }} = 'macOS' ]; then datalad-installer --sudo ok git-annex --method brew elif [ ${{ runner.os }} = 'Windows' ]; then datalad-installer --sudo ok git-annex --method datalad/git-annex:release fi - pip install datalad git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency - shell: bash - - name: Set execute permissions on run_tests.sh - run: chmod +x .github/run_tests.sh - shell: bash - - name: Test core - run: pytest -m "core" + - name: Set execute permissions on run_tests.sh shell: bash + run: chmod +x .github/run_tests.sh - name: Test extractors + shell: bash env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell - run: pytest -m "extractors" - shell: bash + if: env.RUN_EXTRACTORS_TESTS == 'true' + run: | + pip install -e .[extractors,streaming_extractors] + ./.github/run_tests.sh "extractors and not streaming_extractors" --no-virtual-env - name: Test preprocessing - run: ./.github/run_tests.sh "preprocessing and not deepinterpolation" --no-virtual-env shell: bash + if: env.RUN_PREPROCESSING_TESTS == 'true' + run: | + pip install -e .[preprocessing] + ./.github/run_tests.sh "preprocessing and not deepinterpolation" --no-virtual-env - name: Test postprocessing - run: ./.github/run_tests.sh postprocessing --no-virtual-env shell: bash + if: 
env.RUN_POSTPROCESSING_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh postprocessing --no-virtual-env - name: Test quality metrics - run: ./.github/run_tests.sh qualitymetrics --no-virtual-env shell: bash + if: env.RUN_QUALITYMETRICS_TESTS == 'true' + run: | + pip install -e .[qualitymetrics] + ./.github/run_tests.sh qualitymetrics --no-virtual-env - name: Test comparison - run: ./.github/run_tests.sh comparison --no-virtual-env shell: bash + if: env.RUN_COMPARISON_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh comparison --no-virtual-env - name: Test core sorters - run: ./.github/run_tests.sh sorters --no-virtual-env shell: bash + if: env.RUN_SORTERS_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh sorters --no-virtual-env - name: Test internal sorters - run: ./.github/run_tests.sh sorters_internal --no-virtual-env shell: bash + if: env.RUN_INTERNAL_SORTERS_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh sorters_internal --no-virtual-env - name: Test curation - run: ./.github/run_tests.sh curation --no-virtual-env shell: bash + if: env.RUN_CURATION_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh curation --no-virtual-env - name: Test widgets - run: ./.github/run_tests.sh widgets --no-virtual-env shell: bash + if: env.RUN_WIDGETS_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh widgets --no-virtual-env - name: Test exporters - run: ./.github/run_tests.sh exporters --no-virtual-env shell: bash + if: env.RUN_EXPORTERS_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh exporters --no-virtual-env - name: Test sortingcomponents - run: ./.github/run_tests.sh sortingcomponents --no-virtual-env shell: bash + if: env.RUN_SORTINGCOMPONENTS_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh sortingcomponents --no-virtual-env - name: Test generation - run: ./.github/run_tests.sh generation --no-virtual-env shell: bash + if: env.RUN_GENERATION_TESTS == 'true' + run: | + pip install -e .[full] + ./.github/run_tests.sh generation --no-virtual-env diff --git a/pyproject.toml b/pyproject.toml index b26337ad01..72bf376a31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,11 @@ streaming_extractors = [ "s3fs" ] +preprocessing = [ + "scipy", +] + + full = [ "h5py", "pandas", diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 871bdeaed3..a7f30cfc45 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -16,7 +16,12 @@ def test_compare_multiple_templates(): duration = 60 num_channels = 8 - rec, sort = generate_ground_truth_recording(durations=[duration], num_channels=num_channels) + seed = 0 + rec, sort = generate_ground_truth_recording( + durations=[duration], + num_channels=num_channels, + seed=seed, + ) # split recording in 3 equal slices fs = rec.get_sampling_frequency() From efede134e52a0a01e1665cffb5543a696673b525 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 27 Jun 2024 08:50:11 +0100 Subject: [PATCH 081/103] use_names_as_ids update --- src/spikeinterface/extractors/neoextractors/blackrock.py | 5 +++-- src/spikeinterface/extractors/neoextractors/intan.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git 
a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 0015fd9f67..ab3710e05e 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -26,8 +26,9 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. - use_names_as_ids : bool or None, default: None - If True, use channel names as IDs. If None, use default IDs. + use_names_as_ids : bool, default: False + If False, use default IDs inherited from Neo. If True, use channel names as IDs. + """ mode = "file" diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 50fda79123..43439b80c9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -28,7 +28,9 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. use_names_as_ids : bool, default: False - If True, use channel names as IDs. If False, use default IDs inherited from neo. + If False, use default IDs inherited from Neo. If True, use channel names as IDs. + + """ mode = "file" From 713a6612af89db7983f621bd03de1b5b22a754de Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 10:16:57 +0200 Subject: [PATCH 082/103] Add ibllib to pteprocessing requirements --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 72bf376a31..c801f1f735 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ streaming_extractors = [ preprocessing = [ "scipy", + "ibllib>=2.36.0", # for IBL preprocessing ] From dbe3ef2b095af84b4ab8ebc0b0396b97de576ef0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 10:26:21 +0200 Subject: [PATCH 083/103] Move iblibb in test dependencies --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c801f1f735..69f4067d13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,6 @@ streaming_extractors = [ preprocessing = [ "scipy", - "ibllib>=2.36.0", # for IBL preprocessing ] @@ -137,6 +136,9 @@ test = [ "xarray", "huggingface_hub", + # preprocessing + "ibllib>=2.36.0", # for IBL + # tridesclous "numba", "hdbscan>=0.8.33", # Previous version had a broken wheel From 1aa036885b3fefc3bf8440ee2a7cd71295badf0f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:04:03 +0200 Subject: [PATCH 084/103] Move drift_raster_map to motion, typing, docs, and tests --- src/spikeinterface/widgets/driftmap.py | 143 --------- src/spikeinterface/widgets/motion.py | 284 +++++++++++++++--- .../widgets/tests/test_widgets.py | 48 +-- src/spikeinterface/widgets/widget_list.py | 7 +- 4 files changed, 266 insertions(+), 216 deletions(-) delete mode 100644 src/spikeinterface/widgets/driftmap.py diff --git a/src/spikeinterface/widgets/driftmap.py b/src/spikeinterface/widgets/driftmap.py deleted file mode 100644 index 60e8df2972..0000000000 --- a/src/spikeinterface/widgets/driftmap.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from .base 
import BaseWidget, to_attr - - -class DriftMapWidget(BaseWidget): - """ - Plot the a drift map from a motion info dictionary. - - Parameters - ---------- - peaks : np.array - The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index") - peak_locations : np.array - The peak locations, with dtype ("x", "y") or ("x", "y", "z") - direction : "x" or "y", default: "y" - The direction to display - segment_index : int, default: None - The segment index to display. - recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times) - segment_index : int, default: 0 - The segment index to display. - sampling_frequency : float, default: None - The sampling frequency (needed if recording is None) - depth_lim : tuple or None, default: None - The min and max depth to display, if None (min and max of the recording) - color_amplitude : bool, default: True - If True, the color of the scatter points is the amplitude of the peaks - scatter_decimate : int, default: None - If > 1, the scatter points are decimated - cmap : str, default: "inferno" - The colormap to use for the amplitude - clim : tuple or None, default: None - The min and max amplitude to display, if None (min and max of the amplitudes) - alpha : float, default: 1 - The alpha of the scatter points - """ - - def __init__( - self, - peaks, - peak_locations, - direction="y", - recording=None, - sampling_frequency=None, - segment_index=None, - depth_lim=None, - color_amplitude=True, - scatter_decimate=None, - cmap="inferno", - clim=None, - alpha=1, - backend=None, - **backend_kwargs, - ): - if segment_index is None: - assert ( - len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there is only one segment in the peaks array" - assert recording or sampling_frequency, "recording or sampling_frequency must be specified" - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = recording.get_times(segment_index=segment_index) - else: - times = None - - plot_data = dict( - peaks=peaks, - peak_locations=peak_locations, - direction=direction, - times=times, - sampling_frequency=sampling_frequency, - segment_index=segment_index, - depth_lim=depth_lim, - color_amplitude=color_amplitude, - scatter_decimate=scatter_decimate, - cmap=cmap, - clim=clim, - alpha=alpha, - recording=recording, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from matplotlib.colors import Normalize - - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - - dp = to_attr(data_plot) - - assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - fig = self.figure - - if dp.times is None: - # temporal_bins_plot = dp.temporal_bins - x = dp.peaks["sample_index"] / dp.sampling_frequency - else: - # use real times and adjust temporal bins with t_start - # temporal_bins_plot = dp.temporal_bins + dp.times[0] - x = dp.times[dp.peaks["sample_index"]] - - y = dp.peak_locations[dp.direction] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = dp.peaks["amplitude"] - amps_abs = np.abs(amps) - q_95 = np.quantile(amps_abs, 0.95) - if 
dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.colormaps[dp.cmap] - if dp.clim is None: - amps = amps_abs - amps /= q_95 - c = cmap(amps) - else: - norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) - c = cmap(norm_function(amps)) - color_kwargs = dict( - color=None, - c=c, - alpha=dp.alpha, - ) - else: - color_kwargs = dict(color="k", c=None, alpha=dp.alpha) - - self.ax.scatter(x, y, s=1, **color_kwargs) - if dp.depth_lim is not None: - self.ax.set_ylim(*dp.depth_lim) - self.ax.set_title("Peak depth") - self.ax.set_xlabel("Times [s]") - self.ax.set_ylabel("Depth [$\\mu$m]") diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 7d733523df..ee1599822f 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -3,31 +3,32 @@ import numpy as np from .base import BaseWidget, to_attr -from .driftmap import DriftMapWidget + +from spikeinterface.core import BaseRecording, SortingAnalyzer +from spikeinterface.sortingcomponents.motion_utils import Motion class MotionWidget(BaseWidget): """ - Plot the Motion object + Plot the Motion object. Parameters ---------- motion : Motion - The motion object - segment_index : None | int - If Motion is multi segment, the must be not None - mode : "auto" | "line" | "map" - How to plot map or lines. - "auto" make it automatic if the number of depth is too high. + The motion object. + segment_index : int | None, default: None + If Motion is multi segment, the must be not None. + mode : "auto" | "line" | "map", default: "line" + How to plot map or lines. "auto" makes it automatic if the number of motion depths is too high. """ def __init__( self, - motion, - segment_index=None, - mode="line", - motion_lim=None, - backend=None, + motion: Motion, + segment_index: int | None = None, + mode: str = "line", + motion_lim: float | None = None, + backend: str | None = None, **backend_kwargs, ): if isinstance(motion, dict): @@ -51,19 +52,15 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from matplotlib.colors import Normalize dp = to_attr(data_plot) - motion = data_plot["motion"] - segment_index = data_plot["segment_index"] - assert backend_kwargs["axes"] is None self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + motion = dp.motion displacement = motion.displacement[dp.segment_index] temporal_bins_s = motion.temporal_bins_s[dp.segment_index] depth = motion.spatial_bins_um @@ -97,55 +94,241 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_ylabel("Depth [um]") +class DriftRasterMapWidget(BaseWidget): + """ + Plot the drift raster map from peaks or a SortingAnalyzer. + + Parameters + ---------- + peaks : np.array | None, default: None + The peaks array, with dtype ("sample_index", "channel_index", "amplitude", "segment_index"), + as returned by the `detect_peaks` or `correct_motion` functions. + peak_locations : np.array | None, default: None + The peak locations, with dtype ("x", "y") or ("x", "y", "z"), as returned by the + `localize_peaks` or `correct_motion` functions. + sorting_analyzer : SortingAnalyzer | None, default: None + The sorting analyzer object. To use this function, the `SortingAnalyzer` must have the + "spike_locations" extension computed. 
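# Usage sketch (editorial addition, not part of this commit) for the DriftRasterMapWidget documented
# above, assembled from the calls exercised in test_widgets.py later in this series. The synthetic
# recording and the preset are only for illustration; plot_drift_raster_map exists once this patch
# is applied, and any peaks/peak_locations pair from correct_motion or detect_peaks works.
import spikeinterface.full as si
import spikeinterface.widgets as sw

recording, sorting = si.generate_ground_truth_recording(durations=[30.0], num_channels=16, seed=0)
_, motion_info = si.correct_motion(recording, preset="kilosort_like", output_motion_info=True)

# scatter of peak depth vs time, colored by peak amplitude
sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording,
    color_amplitude=True,
)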
+ direction : "x" or "y", default: "y" + The direction to display. + segment_index : int, default: None + The segment index to display. + recording : RecordingExtractor | None, default: None + The recording extractor object (only used to get "real" times). + segment_index : int, default: 0 + The segment index to display. + sampling_frequency : float, default: None + The sampling frequency (needed if recording is None). + depth_lim : tuple or None, default: None + The min and max depth to display, if None (min and max of the recording). + scatter_decimate : int, default: None + If > 1, the scatter points are decimated. + color_amplitude : bool, default: True + If True, the color of the scatter points is the amplitude of the peaks. + cmap : str, default: "inferno" + The colormap to use for the amplitude. + color : str, default: "Gray" + The color of the scatter points if color_amplitude is False. + clim : tuple or None, default: None + The min and max amplitude to display, if None (min and max of the amplitudes). + alpha : float, default: 1 + The alpha of the scatter points. + """ + + def __init__( + self, + peaks: np.array | None = None, + peak_locations: np.array | None = None, + sorting_analyzer: SortingAnalyzer | None = None, + direction: str = "y", + recording: BaseRecording | None = None, + sampling_frequency: float | None = None, + segment_index: int | None = None, + depth_lim: tuple[float, float] | None = None, + color_amplitude: bool = True, + scatter_decimate: int | None = None, + cmap: str = "inferno", + color: str = "Gray", + clim: tuple[float, float] | None = None, + alpha: float = 1, + backend: str | None = None, + **backend_kwargs, + ): + assert peaks is not None or sorting_analyzer is not None + if peaks is not None: + assert peak_locations is not None + if recording is None: + assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency" + else: + sampling_frequency = recording.sampling_frequency + peak_amplitudes = peaks["amplitude"] + if sorting_analyzer is not None: + if sorting_analyzer.has_recording(): + recording = sorting_analyzer.recording + else: + recording = None + sampling_frequency = sorting_analyzer.sampling_frequency + peaks = sorting_analyzer.sorting.to_spike_vector() + assert sorting_analyzer.has_extension( + "spike_locations" + ), "The sorting analyzer must have the 'spike_locations' extension to use this function" + peak_locations = sorting_analyzer.get_extension("spike_locations").get_data() + if color_amplitude: + assert sorting_analyzer.has_extension("spike_amplitudes"), ( + "The sorting analyzer must have the 'spike_amplitudes' extension to use color_amplitude=True. " + "You can compute it or set color_amplitude=False." 
+ ) + if sorting_analyzer.has_extension("spike_amplitudes"): + peak_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + else: + peak_amplitudes = None + times = recording.get_times(segment_index=segment_index) if recording is not None else None + + if segment_index is None: + assert ( + len(np.unique(peaks["segment_index"])) == 1 + ), "segment_index must be specified if there is only one segment in the peaks array" + segment_index = 0 + else: + peak_mask = peaks["segment_index"] == segment_index + peaks = peaks[peak_mask] + peak_locations = peak_locations[peak_mask] + if peak_amplitudes is not None: + peak_amplitudes = peak_amplitudes[peak_mask] + + if recording is not None: + sampling_frequency = recording.sampling_frequency + times = recording.get_times(segment_index=segment_index) + else: + times = None + + plot_data = dict( + peaks=peaks, + peak_locations=peak_locations, + peak_amplitudes=peak_amplitudes, + direction=direction, + times=times, + sampling_frequency=sampling_frequency, + segment_index=segment_index, + depth_lim=depth_lim, + color_amplitude=color_amplitude, + color=color, + scatter_decimate=scatter_decimate, + cmap=cmap, + clim=clim, + alpha=alpha, + recording=recording, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from matplotlib.colors import Normalize + from .utils_matplotlib import make_mpl_figure + + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + + if dp.times is None: + x = dp.peaks["sample_index"] / dp.sampling_frequency + else: + x = dp.times[dp.peaks["sample_index"]] + + y = dp.peak_locations[dp.direction] + if dp.scatter_decimate is not None: + x = x[:: dp.scatter_decimate] + y = y[:: dp.scatter_decimate] + y2 = y2[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peak_amplitudes + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.colormaps[dp.cmap] + if dp.clim is None: + amps = amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) + c = cmap(norm_function(amps)) + color_kwargs = dict( + color=None, + c=c, + alpha=dp.alpha, + ) + else: + color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha) + + self.ax.scatter(x, y, s=1, **color_kwargs) + if dp.depth_lim is not None: + self.ax.set_ylim(*dp.depth_lim) + self.ax.set_title("Peak depth") + self.ax.set_xlabel("Times [s]") + self.ax.set_ylabel("Depth [$\\mu$m]") + + class MotionInfoWidget(BaseWidget): """ - Plot motion information from the motion_info dict returned by correct_motion(). - This plot: - * the motion iself - * the peak depth vs time before correction - * the peak depth vs time after correction + Plot motion information from the motion_info dictionary returned by the `correct_motion()` funciton. 
+ This widget plots:: + * the motion iself + * the drift raster map (peak depth vs time) before correction + * the drift raster map (peak depth vs time) after correction Parameters ---------- motion_info : dict - The motion info returned by correct_motion() or loaded back with load_motion_info() + The motion info returned by correct_motion() or loaded back with load_motion_info(). segment_index : int, default: None The segment index to display. recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times) + The recording extractor object (only used to get "real" times). segment_index : int, default: 0 The segment index to display. sampling_frequency : float, default: None - The sampling frequency (needed if recording is None) + The sampling frequency (needed if recording is None). depth_lim : tuple or None, default: None - The min and max depth to display, if None (min and max of the recording) + The min and max depth to display, if None (min and max of the recording). motion_lim : tuple or None, default: None - The min and max motion to display, if None (min and max of the motion) - color_amplitude : bool, default: False - If True, the color of the scatter points is the amplitude of the peaks + The min and max motion to display, if None (min and max of the motion). scatter_decimate : int, default: None - If > 1, the scatter points are decimated + If > 1, the scatter points are decimated. + color_amplitude : bool, default: False + If True, the color of the scatter points is the amplitude of the peaks. amplitude_cmap : str, default: "inferno" - The colormap to use for the amplitude + The colormap to use for the amplitude. + amplitude_color : str, default: "Gray" + The color of the scatter points if color_amplitude is False. amplitude_clim : tuple or None, default: None - The min and max amplitude to display, if None (min and max of the amplitudes) + The min and max amplitude to display, if None (min and max of the amplitudes). amplitude_alpha : float, default: 1 - The alpha of the scatter points + The alpha of the scatter points. 
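# Usage sketch (editorial addition, not part of this commit) for MotionInfoWidget. It takes the
# motion_info dictionary returned by correct_motion(..., output_motion_info=True); the setup lines
# mirror the drift raster map sketch earlier in this series and are illustrative only.
import spikeinterface.full as si
import spikeinterface.widgets as sw

recording, _ = si.generate_ground_truth_recording(durations=[30.0], num_channels=16, seed=0)
_, motion_info = si.correct_motion(recording, preset="kilosort_like", output_motion_info=True)

# motion traces plus the before/after-correction drift raster maps in one figure
sw.plot_motion_info(motion_info, recording=recording)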
""" def __init__( self, - motion_info, - segment_index=None, - recording=None, - depth_lim=None, - motion_lim=None, - color_amplitude=False, - scatter_decimate=None, - amplitude_cmap="inferno", - amplitude_clim=None, - amplitude_alpha=1, - backend=None, + motion_info: dict, + segment_index: int | None = None, + recording: BaseRecording | None = None, + depth_lim: tuple[float, float] | None = None, + motion_lim: tuple[float, float] | None = None, + color_amplitude: bool = False, + scatter_decimate: int | None = None, + amplitude_cmap: str = "inferno", + amplitude_color: str = "Gray", + amplitude_clim: tuple[float, float] | None = None, + amplitude_alpha: float = 1, + backend: str | None = None, **backend_kwargs, ): @@ -169,6 +352,7 @@ def __init__( color_amplitude=color_amplitude, scatter_decimate=scatter_decimate, amplitude_cmap=amplitude_cmap, + amplitude_color=amplitude_color, amplitude_clim=amplitude_clim, amplitude_alpha=amplitude_alpha, recording=recording, @@ -178,9 +362,7 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from matplotlib.colors import Normalize from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks @@ -229,15 +411,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording=dp.recording, segment_index=dp.segment_index, depth_lim=dp.depth_lim, - color_amplitude=dp.color_amplitude, scatter_decimate=dp.scatter_decimate, + color_amplitude=dp.color_amplitude, + color=dp.amplitude_color, cmap=dp.amplitude_cmap, clim=dp.amplitude_clim, alpha=dp.amplitude_alpha, backend="matplotlib", ) - drift_map = DriftMapWidget( + # with immediate_plot=True the widgets are plotted immediately + _ = DriftRasterMapWidget( dp.peaks, dp.peak_locations, ax=ax0, @@ -245,7 +429,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): **commpon_drift_map_kwargs, ) - drift_map_corrected = DriftMapWidget( + _ = DriftRasterMapWidget( dp.peaks, corrected_location, ax=ax1, diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index e841a1c93b..0eef8539cc 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -22,7 +22,7 @@ import spikeinterface.widgets as sw import spikeinterface.comparison as sc -from spikeinterface.preprocessing import scale +from spikeinterface.preprocessing import scale, correct_motion ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -56,6 +56,9 @@ def setUpClass(cls): cls.recording = recording cls.sorting = sorting + # estimate motion for motion widgets + _, cls.motion_info = correct_motion(recording, preset="kilosort_like", output_motion_info=True) + cls.num_units = len(cls.sorting.get_unit_ids()) extensions_to_compute = dict( @@ -581,9 +584,7 @@ def test_plot_multicomparison(self): sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) def test_plot_motion(self): - from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - - motion = make_fake_motion() + motion = self.motion_info["motion"] possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: @@ -591,22 +592,31 @@ def test_plot_motion(self): sw.plot_motion(motion, backend=backend, mode="line") sw.plot_motion(motion, backend=backend, mode="map") - def test_plot_motion_info(self): - from 
spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - - motion = make_fake_motion() - rng = np.random.default_rng(seed=2205) - peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) - peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) - - motion_info = dict( - motion=motion, - parameters=dict(sampling_frequency=30000.0), - run_times=dict(), - peaks=self.peaks, - peak_locations=peak_locations, - ) + def test_drift_raster_map(self): + peaks = self.motion_info["peaks"] + recording = self.recording + peak_locations = self.motion_info["peak_locations"] + analyzer = self.sorting_analyzer_sparse + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + # with recoridng + sw.plot_drift_raster_map( + peaks=peaks, peak_locations=peak_locations, recording=recording, color_amplitude=True + ) + # without recording + sw.plot_drift_raster_map( + peaks=peaks, + peak_locations=peak_locations, + sampling_frequency=recording.sampling_frequency, + color_amplitude=False, + ) + # with analyzer + sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True) + + def test_plot_motion_info(self): + motion_info = self.motion_info possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8d4accaa7e..8163271ec4 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,9 +9,8 @@ from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget -from .driftmap import DriftMapWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget, MotionInfoWidget +from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget from .peaks_on_probe import PeaksOnProbeWidget @@ -45,7 +44,7 @@ ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, - DriftMapWidget, + DriftRasterMapWidget, ISIDistributionWidget, MotionWidget, MotionInfoWidget, @@ -120,7 +119,7 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget -plot_drift_map = DriftMapWidget +plot_drift_raster_map = DriftRasterMapWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget plot_motion_info = MotionInfoWidget From 30b60e7eab49bfa47696593e8f7f3506113cda53 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:05:47 +0200 Subject: [PATCH 085/103] Add explanation on what drift rastermap is --- src/spikeinterface/widgets/motion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index ee1599822f..66ef2a3f01 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -97,6 +97,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class DriftRasterMapWidget(BaseWidget): """ Plot the drift raster map from peaks or a 
SortingAnalyzer. + The drift raster map is a scatter plot of the estimated peak depth vs time and it is + useful to visualize the drift over the course of the recording. Parameters ---------- From 31064ec453f65cac23baa2379991b0996492618b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:06:39 +0200 Subject: [PATCH 086/103] Add explanation on 'y' direction --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 66ef2a3f01..31edbf2f4d 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -112,7 +112,7 @@ class DriftRasterMapWidget(BaseWidget): The sorting analyzer object. To use this function, the `SortingAnalyzer` must have the "spike_locations" extension computed. direction : "x" or "y", default: "y" - The direction to display. + The direction to display. "y" is the depth direction. segment_index : int, default: None The segment index to display. recording : RecordingExtractor | None, default: None From cc550b9622bee8bf11a11b585ee9ff02cb829423 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:07:35 +0200 Subject: [PATCH 087/103] Fix segment index error --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 31edbf2f4d..31a938829d 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -189,7 +189,7 @@ def __init__( if segment_index is None: assert ( len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there is only one segment in the peaks array" + ), "segment_index must be specified if there are multiple segments" segment_index = 0 else: peak_mask = peaks["segment_index"] == segment_index From 80ba2e512f568a2b96ea3e38095bc19f9a987480 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:12:03 +0200 Subject: [PATCH 088/103] Review suggestions and test with scatter_decimate --- src/spikeinterface/widgets/motion.py | 16 +++++++--------- src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 31a938829d..895a8733c7 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -232,21 +232,19 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["axes"] is None, "axes argument is not allowed in DriftRasterMapWidget. Use ax instead." 
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - fig = self.figure if dp.times is None: - x = dp.peaks["sample_index"] / dp.sampling_frequency + peak_times = dp.peaks["sample_index"] / dp.sampling_frequency else: - x = dp.times[dp.peaks["sample_index"]] + peak_times = dp.times[dp.peaks["sample_index"]] - y = dp.peak_locations[dp.direction] + peak_locs = dp.peak_locations[dp.direction] if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] + peak_times = peak_times[:: dp.scatter_decimate] + peak_locs = peak_locs[:: dp.scatter_decimate] if dp.color_amplitude: amps = dp.peak_amplitudes @@ -271,7 +269,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha) - self.ax.scatter(x, y, s=1, **color_kwargs) + self.ax.scatter(peak_times, peak_locs, s=1, **color_kwargs) if dp.depth_lim is not None: self.ax.set_ylim(*dp.depth_lim) self.ax.set_title("Peak depth") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 0eef8539cc..7887ecda66 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -613,7 +613,7 @@ def test_drift_raster_map(self): color_amplitude=False, ) # with analyzer - sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True) + sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True, scatter_decimate=2) def test_plot_motion_info(self): motion_info = self.motion_info From 3e9f342e6a8d7695186c2aef4e12cde30d984cea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 12:22:09 +0200 Subject: [PATCH 089/103] Mark failing sorter test on Windows*Python3.12 as xfail --- src/spikeinterface/sorters/tests/test_runsorter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py index 470bdc3602..6bd73c5691 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter.py +++ b/src/spikeinterface/sorters/tests/test_runsorter.py @@ -1,7 +1,9 @@ import os +import platform import pytest from pathlib import Path import shutil +from packaging.version import parse from spikeinterface import generate_ground_truth_recording from spikeinterface.sorters import run_sorter @@ -19,6 +21,10 @@ def generate_recording(): return _generate_recording() +@pytest.mark.xfail( + platform.system() == "Windows" and parse(platform.python_version()) > parse("3.12"), + reason="3rd parth threadpoolctl issue: OSError('GetModuleFileNameEx failed')", +) def test_run_sorter_local(generate_recording, create_cache_folder): recording = generate_recording cache_folder = create_cache_folder From d1d65f6ca6338ac2dd8d6f9c99ee657f0db76d21 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 11:58:23 +0100 Subject: [PATCH 090/103] estimate_sparsity arg ordering --- src/spikeinterface/core/sortinganalyzer.py | 2 +- src/spikeinterface/core/sparsity.py | 6 +++--- src/spikeinterface/core/tests/test_sparsity.py | 4 ++-- .../postprocessing/tests/common_extension_tests.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 53e060262b..62b7f9e7c0 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -127,7 +127,7 @@ def 
create_sorting_analyzer( recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" elif sparse: - sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) else: sparsity = None diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index cefd7bd950..1cd7822f99 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -539,8 +539,8 @@ def compute_sparsity( def estimate_sparsity( - recording: BaseRecording, sorting: BaseSorting, + recording: BaseRecording, num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, @@ -563,10 +563,10 @@ def estimate_sparsity( Parameters ---------- - recording: BaseRecording - The recording sorting: BaseSorting The sorting + recording: BaseRecording + The recording num_spikes_for_sparsity: int, default: 100 How many spikes per units to compute the sparsity ms_before: float, default: 1.0 diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 98d033d8ea..a192d90502 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -166,8 +166,8 @@ def test_estimate_sparsity(): # small radius should give a very sparse = one channel per unit sparsity = estimate_sparsity( - recording, sorting, + recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, @@ -182,8 +182,8 @@ def test_estimate_sparsity(): # best_channel : the mask should exactly 3 channels per units sparsity = estimate_sparsity( - recording, sorting, + recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bf462a9466..8c46fa5e24 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -79,7 +79,7 @@ class AnalyzerExtensionCommonTestSuite: def setUpClass(cls): cls.recording, cls.sorting = get_dataset() # sparsity is computed once for all cases to save processing time and force a small radius - cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20) + cls.sparsity = estimate_sparsity(cls.sorting, cls.recording, method="radius", radius_um=20) @property def extension_name(self): From 02ae32a857c9ce59a54deffcc1465a3d975342aa Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 14:22:18 +0200 Subject: [PATCH 091/103] Update src/spikeinterface/widgets/motion.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 895a8733c7..5f0e02fdab 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -280,7 +280,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class MotionInfoWidget(BaseWidget): """ Plot motion information from the motion_info dictionary returned by the `correct_motion()` funciton. 
- This widget plots:: + This widget plots: * the motion itself * the drift raster map (peak depth vs time) before correction * the drift raster map (peak depth vs time) after correction From c111cfcacb1c80b4166320c3f3753a2a7d629f69 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 14:22:28 +0200 Subject: [PATCH 092/103] Update src/spikeinterface/widgets/tests/test_widgets.py Co-authored-by: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> --- src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7887ecda66..012b1ac07c 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -601,7 +601,7 @@ def test_drift_raster_map(self): possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - # with recoridng + # with recording sw.plot_drift_raster_map( peaks=peaks, peak_locations=peak_locations, recording=recording, color_amplitude=True ) From 12d823bb0dc7e1536486508c473f0ce5562e395a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 14:37:10 +0200 Subject: [PATCH 093/103] Better docs for plot mode (line, map, auto) --- src/spikeinterface/widgets/motion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 895a8733c7..766938299a 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -19,7 +19,10 @@ class MotionWidget(BaseWidget): segment_index : int | None, default: None If Motion is multi-segment, this must not be None. mode : "auto" | "line" | "map", default: "line" - How to plot map or lines. "auto" makes it automatic if the number of motion depths is too high. + How to plot the motion. + "line" plots estimated motion at different depths as lines. + "map" plots estimated motion at different depths as a heatmap. + "auto" makes it automatic depending on the number of motion depths.
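A minimal usage sketch of the three modes documented above (assumptions: a Motion object named `motion`, e.g. taken from the `motion_info` dictionary returned by `correct_motion()`, and the `plot_motion` entry point of `spikeinterface.widgets`):

import spikeinterface.widgets as sw

sw.plot_motion(motion, mode="line")  # one curve of estimated motion per depth
sw.plot_motion(motion, mode="map")   # depth-by-time heatmap of the estimated motion
sw.plot_motion(motion, mode="auto")  # picks between the two depending on how many depths there are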
""" def __init__( From a3deed8211f9b20e3acbe41f9b7297e285ba68ed Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Jun 2024 14:39:07 +0200 Subject: [PATCH 094/103] Remove duplicated line --- src/spikeinterface/widgets/motion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index bf9010c144..0b79350a62 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -187,7 +187,6 @@ def __init__( peak_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() else: peak_amplitudes = None - times = recording.get_times(segment_index=segment_index) if recording is not None else None if segment_index is None: assert ( From 2cc719986e5d6fceb9ea828206d7cf1d9a3fef9a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 08:11:55 -0600 Subject: [PATCH 095/103] @alejo91 suggestion --- src/spikeinterface/core/core_tools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index d5480d6f00..066ab58d8c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -184,10 +184,10 @@ def is_dict_extractor(d: dict) -> bool: return is_extractor -recording_dict_element = namedtuple(typename="recording_dict_element", field_names=["value", "name", "access_path"]) +extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"]) -def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. @@ -204,7 +204,7 @@ def extractor_dict_iterator(extractor_dict: dict) -> Generator[recording_dict_el Yields ------ - recording_dict_element + extractor_dict_element Named tuple containing the value, the name, and the access_path to the value in the dictionary. 
""" @@ -219,7 +219,7 @@ def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): v, access_path + (i,), name=name ) # Propagate name of list to children else: - yield recording_dict_element( + yield extractor_dict_element( value=dict_list_or_value, name=name, access_path=access_path, @@ -320,7 +320,7 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -# This is the current definition that an element in a recording_dict is a path +# This is the current definition that an element in a extractor_dict is a path # This is shared across a couple of definition so it is here for DNRY element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) From 61060781eef87597461241aec077aac27baff69b Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:15:14 +0100 Subject: [PATCH 096/103] SpikeRetriever arg switch --- src/spikeinterface/core/node_pipeline.py | 16 +-- .../core/tests/test_node_pipeline.py | 4 +- .../tests/test_train_manual_curation.py | 120 ++++++++++++++++++ .../postprocessing/amplitude_scalings.py | 2 +- .../postprocessing/spike_amplitudes.py | 2 +- .../postprocessing/spike_locations.py | 2 +- 6 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1c0107d235..0722ede23f 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource): * compute_spike_amplitudes() * compute_principal_components() + sorting : BaseSorting + The sorting object. recording : BaseRecording The recording object. - sorting: BaseSorting - The sorting object. - channel_from_template: bool, default: True + channel_from_template : bool, default: True If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. If False, the max channel is computed for each spike given a radius around the template max channel. - extremum_channel_inds: dict of int | None, default: None + extremum_channel_inds : dict of int | None, default: None The extremum channel index dict given from template. - radius_um: float, default: 50 + radius_um : float, default: 50 The radius to find the real max channel. Used only when channel_from_template=False - peak_sign: "neg" | "pos", default: "neg" + peak_sign : "neg" | "pos", default: "neg" Peak sign to find the max channel. 
Used only when channel_from_template=False - include_spikes_in_margin: bool, default False + include_spikes_in_margin : bool, default False If not None then spikes in margin are added and an extra filed in dtype is added """ def __init__( self, - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=None, radius_um=50, diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 03acc9fed1..8d788acbad 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation): peak_retriever = PeakRetriever(recording, peaks) # channel index is from template spike_retriever_T = SpikeRetriever( - recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) # channel index is per spike spike_retriever_S = SpikeRetriever( - recording, sorting, + recording, channel_from_template=False, extremum_channel_inds=extremum_channel_inds, radius_um=50, diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py new file mode 100644 index 0000000000..f0f9ff4d75 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -0,0 +1,120 @@ +import pytest +import pandas as pd +import os +import shutil + +from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model + +# Sample data for testing +data = { + 'num_spikes': [1, 2, 3, 4, 5, 6], + 'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], + 'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06], + 'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], + 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'firing_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'num_positive_peaks': [1, 2, 3, 4, 5, 6], + 'num_negative_peaks': [1, 2, 3, 4, 5, 6], + 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + 'is_noise': [0, 1, 0, 1, 0, 1], + 'is_sua': [1, 0, 1, 0, 1, 0], + 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] +} + +df = pd.DataFrame(data) + +# Test initialization +def test_initialization(): + trainer = 
CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + assert trainer.output_folder == '/tmp' + assert trainer.curator_column == 'num_spikes' + assert trainer.imputation_strategies is not None + assert trainer.scaling_techniques is not None + +# Test load_data_file +def test_load_data_file(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + df.to_csv('/tmp/test.csv', index=False) + trainer.load_data_file('/tmp/test.csv') + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Test process_test_data_for_classification +def test_process_test_data_for_classification(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + trainer.testing_metrics = {0: df} + trainer.process_test_data_for_classification() + assert trainer.noise_test is not None + assert trainer.sua_mua_test is not None + +# Test apply_scaling_imputation +def test_apply_scaling_imputation(): + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) + y_train = df['is_noise'] + y_val = df['is_noise'] + result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) + assert result is not None + +# Test get_classifier_search_space +def test_get_classifier_search_space(): + from sklearn.linear_model import LogisticRegression + trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') + model, param_space = trainer.get_classifier_search_space(LogisticRegression) + assert model is not None + assert param_space is not None + +# Test Objective Enum +def test_objective_enum(): + assert Objective.Noise == Objective(1) + assert Objective.SUA == Objective(2) + assert str(Objective.Noise) == "Objective.Noise" + assert str(Objective.SUA) == "Objective.SUA" + +# Test train_model function +def test_train_model(monkeypatch): + output_folder = '/tmp/output' + os.makedirs(output_folder, exist_ok=True) + df.to_csv('/tmp/metrics.csv', index=False) + + def mock_load_and_preprocess_full(self, path): + self.testing_metrics = {0: df} + self.process_test_data_for_classification() + + monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) + + trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') + assert trainer is not None + assert trainer.testing_metrics is not None + assert 0 in trainer.testing_metrics + +# Clean up temporary files +@pytest.fixture(scope="module", autouse=True) +def cleanup(request): + def remove_tmp(): + shutil.rmtree('/tmp', ignore_errors=True) + request.addfinalizer(remove_tmp) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 2e544d086b..8ff9cc5666 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -170,8 +170,8 @@ def _get_pipeline_nodes(self): sparsity_mask = sparsity.mask spike_retriever_node = SpikeRetriever( - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices, include_spikes_in_margin=True, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index aebfd1fd78..72cbcb651f 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py 
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -95,7 +95,7 @@ def _get_pipeline_nodes(self): peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) spike_retriever_node = SpikeRetriever( - recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices + sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices ) spike_amplitudes_node = SpikeAmplitudeNode( recording, diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 52a91342b6..23301292e5 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -103,8 +103,8 @@ def _get_pipeline_nodes(self): ) retriever = SpikeRetriever( - recording, sorting, + recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices, ) From 722c313382b6ac225a2c9119c676bc1bcab6e480 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:17:43 +0100 Subject: [PATCH 097/103] has_exceeding_spikes arg switch --- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/frameslicesorting.py | 2 +- src/spikeinterface/core/waveform_tools.py | 2 +- src/spikeinterface/curation/remove_excess_spikes.py | 2 +- .../curation/tests/test_remove_excess_spikes.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fd68df9dda..d9a567dedf 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True): self.get_num_segments() == recording.get_num_segments() ), "The recording has a different number of segments than the sorting!" if check_spike_frames: - if has_exceeding_spikes(recording, self): + if has_exceeding_spikes(self, recording): warnings.warn( "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index ffd8af5fd8..f3ec449ab0 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike assert ( start_frame <= parent_n_samples ), "`start_frame` should be smaller than the sortings' total number of samples." - if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): + if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording): raise ValueError( "The sorting object has spikes whose times go beyond the recording duration." "This could indicate a bug in the sorter. 
" diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index befc49d034..4543074872 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None return waveforms_by_units -def has_exceeding_spikes(recording, sorting) -> bool: +def has_exceeding_spikes(sorting, recording) -> bool: """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 0ae7a59fc6..d1d6b7f3cb 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording): sorting_without_excess_spikes : Sorting The sorting without any excess spikes. """ - if has_exceeding_spikes(recording=recording, sorting=sorting): + if has_exceeding_spikes(sorting=sorting, recording=recording): return RemoveExcessSpikesSorting(sorting=sorting, recording=recording) else: return sorting diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index 69edbaba4c..141cc4c34e 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -39,10 +39,10 @@ def test_remove_excess_spikes(): labels.append(labels_segment) sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency) - assert has_exceeding_spikes(recording, sorting) + assert has_exceeding_spikes(sorting, recording) sorting_corrected = remove_excess_spikes(sorting, recording) - assert not has_exceeding_spikes(recording, sorting_corrected) + assert not has_exceeding_spikes(sorting_corrected, recording) for u in sorting.unit_ids: for segment_index in range(sorting.get_num_segments()): From d0968c4c941e290488848d14c6881c7a2cdf9c8c Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:19:24 +0100 Subject: [PATCH 098/103] removed accidental commit --- .../tests/test_train_manual_curation.py | 120 ------------------ 1 file changed, 120 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/test_train_manual_curation.py diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py deleted file mode 100644 index f0f9ff4d75..0000000000 --- a/src/spikeinterface/curation/tests/test_train_manual_curation.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import pandas as pd -import os -import shutil - -from spikeinterface.curation.train_manual_curation import CurationModelTrainer, Objective, train_model - -# Sample data for testing -data = { - 'num_spikes': [1, 2, 3, 4, 5, 6], - 'firing_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'presence_ratio': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], - 'isi_violations_ratio': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06], - 'amplitude_cutoff': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_median': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7], - 'amplitude_cv_median': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'amplitude_cv_range': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_2': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_4': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'sync_spike_8': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'firing_range': [0.1, 0.2, 0.3, 
0.4, 0.5, 0.6], - 'drift_ptp': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_std': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'drift_mad': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'isolation_distance': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'l_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'd_prime': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'silhouette': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_hit_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'nn_miss_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_to_valley': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'peak_trough_ratio': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'half_width': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'repolarization_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'recovery_slope': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'num_positive_peaks': [1, 2, 3, 4, 5, 6], - 'num_negative_peaks': [1, 2, 3, 4, 5, 6], - 'velocity_above': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'velocity_below': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'exp_decay': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'spread': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - 'is_noise': [0, 1, 0, 1, 0, 1], - 'is_sua': [1, 0, 1, 0, 1, 0], - 'majority_vote': ['good', 'bad', 'good', 'bad', 'good', 'bad'] -} - -df = pd.DataFrame(data) - -# Test initialization -def test_initialization(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - assert trainer.output_folder == '/tmp' - assert trainer.curator_column == 'num_spikes' - assert trainer.imputation_strategies is not None - assert trainer.scaling_techniques is not None - -# Test load_data_file -def test_load_data_file(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - df.to_csv('/tmp/test.csv', index=False) - trainer.load_data_file('/tmp/test.csv') - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Test process_test_data_for_classification -def test_process_test_data_for_classification(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - trainer.testing_metrics = {0: df} - trainer.process_test_data_for_classification() - assert trainer.noise_test is not None - assert trainer.sua_mua_test is not None - -# Test apply_scaling_imputation -def test_apply_scaling_imputation(): - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - X_train = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - X_val = df.drop(columns=['is_noise', 'is_sua', 'majority_vote']) - y_train = df['is_noise'] - y_val = df['is_noise'] - result = trainer.apply_scaling_imputation('median', trainer.scaling_techniques[0][1], X_train, X_val, y_train, y_val) - assert result is not None - -# Test get_classifier_search_space -def test_get_classifier_search_space(): - from sklearn.linear_model import LogisticRegression - trainer = CurationModelTrainer(column_name='num_spikes', output_folder='/tmp') - model, param_space = trainer.get_classifier_search_space(LogisticRegression) - assert model is not None - assert param_space is not None - -# Test Objective Enum -def test_objective_enum(): - assert Objective.Noise == Objective(1) - assert Objective.SUA == Objective(2) - assert str(Objective.Noise) == "Objective.Noise" - assert str(Objective.SUA) == "Objective.SUA" - -# Test train_model function -def test_train_model(monkeypatch): - output_folder = '/tmp/output' - os.makedirs(output_folder, exist_ok=True) - df.to_csv('/tmp/metrics.csv', index=False) - - def mock_load_and_preprocess_full(self, path): - self.testing_metrics = {0: df} - self.process_test_data_for_classification() - - 
monkeypatch.setattr(CurationModelTrainer, 'load_and_preprocess_full', mock_load_and_preprocess_full) - - trainer = train_model('/tmp/metrics.csv', output_folder, 'is_noise') - assert trainer is not None - assert trainer.testing_metrics is not None - assert 0 in trainer.testing_metrics - -# Clean up temporary files -@pytest.fixture(scope="module", autouse=True) -def cleanup(request): - def remove_tmp(): - shutil.rmtree('/tmp', ignore_errors=True) - request.addfinalizer(remove_tmp) From f687c2c2fe9b70a970cfd39d6dd7b134c15e065f Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:20:32 +0100 Subject: [PATCH 099/103] docs --- src/spikeinterface/core/waveform_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 4543074872..98380e955f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -685,10 +685,10 @@ def has_exceeding_spikes(sorting, recording) -> bool: Parameters ---------- - recording : BaseRecording - The recording object sorting : BaseSorting The sorting object + recording : BaseRecording + The recording object Returns ------- From b8c8fa83ba8695545b420d135c92f5167d7d2de1 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Thu, 27 Jun 2024 15:54:59 +0100 Subject: [PATCH 100/103] Missed one --- .../postprocessing/tests/common_extension_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bb2f5aaafd..52dbaf23d4 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -73,7 +73,7 @@ class instance is used for each. 
In this case, we have to set self.__class__.recording, self.__class__.sorting = get_dataset() self.__class__.sparsity = estimate_sparsity( - self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20 ) self.__class__.cache_folder = create_cache_folder From 3eee955a8da3989dda6cbd84b25c0eabc2222527 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 27 Jun 2024 09:01:15 -0600 Subject: [PATCH 101/103] make test skipif --- .../core/tests/test_core_tools.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 724517577c..7153991543 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -31,25 +31,25 @@ def test_add_suffix(): assert str(file_path_with_suffix) == expected_path +@pytest.mark.skipif(platform.system() == "Windows", reason="Runs on posix only") def test_path_utils_functions(): - if platform.system() != "Windows": - # posix path - d = { - "kwargs": { - "path": "/yep/sub/path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": "/yep/sub/path2"}, - }, - } + # posix path + d = { + "kwargs": { + "path": "/yep/sub/path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "/yep/sub/path2"}, + }, } + } - d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) - assert d2["kwargs"]["path"].startswith("/yop") - assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") + d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) + assert d2["kwargs"]["path"].startswith("/yop") + assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") @pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") From c24c9669dcd8e53246c376c6d33eebbf39cbab83 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 27 Jun 2024 18:32:22 +0100 Subject: [PATCH 102/103] Add *sg_execution_times.rst to gitignore. (#3097) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d981c8de4e..6c9fa6869f 100644 --- a/.gitignore +++ b/.gitignore @@ -180,6 +180,7 @@ examples/tutorials/*.svg doc/_build/* doc/tutorials/* doc/sources/* +*sg_execution_times.rst examples/getting_started/tmp_* examples/getting_started/phy From d5ec1806bf41c27317f60e7c96cf71972400774b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:58:30 -0400 Subject: [PATCH 103/103] get rid of waveform term --- src/spikeinterface/widgets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index b94167d2b7..9566989d31 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -139,7 +139,7 @@ def check_extensions(sorting_analyzer, extensions): if not sorting_analyzer.has_extension(extension): raise_error = True error_msg += ( - f"The {extension} waveform extension is required for this widget. " + f"The {extension} sorting analyzer extension is required for this widget. " f"Run the `sorting_analyzer.compute('{extension}', ...)` to compute it.\n" ) if raise_error:
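Taken together, the reordered signatures in this series consistently put the sorting before the recording. A minimal sketch of the updated call pattern (the kwargs are illustrative, and the import paths are assumptions based on the modules touched above):

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core import estimate_sparsity
from spikeinterface.core.waveform_tools import has_exceeding_spikes

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
sparsity = estimate_sparsity(sorting, recording, method="radius", radius_um=20)  # sorting first, then recording
assert not has_exceeding_spikes(sorting, recording)  # same ordering for the helper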