diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py
index 08bb22a2c8..565a743084 100644
--- a/src/spikeinterface/core/tests/test_channelslicerecording.py
+++ b/src/spikeinterface/core/tests/test_channelslicerecording.py
@@ -4,7 +4,7 @@
 import pytest
 import numpy as np
 
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor
 
@@ -58,7 +58,7 @@ def test_ChannelSliceRecording():
     assert np.all(traces[:, 1] == 0)
 
     # with probe and after save()
-    probe = pi.generate_linear_probe(num_elec=num_chan)
+    probe = probeinterface.generate_linear_probe(num_elec=num_chan)
     probe.set_device_channel_indices(np.arange(num_chan))
     rec_p = rec.set_probe(probe)
     rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4])
diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py
index 8b70722652..b1888f4a27 100644
--- a/src/spikeinterface/extractors/bids.py
+++ b/src/spikeinterface/extractors/bids.py
@@ -3,7 +3,7 @@
 import numpy as np
 import neo
 
-from probeinterface import read_BIDS_probe
+import probeinterface
 
 from .nwbextractors import read_nwb
 from .neoextractors import read_nix
@@ -60,7 +60,7 @@ def read_bids(folder_path):
 
 
 def _read_probe_group(folder, bids_name, recording_channel_ids):
-    probegroup = read_BIDS_probe(folder)
+    probegroup = probeinterface.read_BIDS_probe(folder)
 
     # make maps between : channel_id and contact_id using _channels.tsv
     import pandas as pd
diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py
index bd56208ebe..37ed931d1a 100644
--- a/src/spikeinterface/extractors/cbin_ibl.py
+++ b/src/spikeinterface/extractors/cbin_ibl.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface.core import BaseRecording, BaseRecordingSegment
 from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
@@ -89,7 +89,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"):
         self.set_channel_offsets(offsets)
 
         if not load_sync_channel:
-            probe = pi.read_spikeglx(meta_file)
+            probe = probeinterface.read_spikeglx(meta_file)
 
             if probe.shank_ids is not None:
                 self.set_probe(probe, in_place=True, group_mode="by_shank")
diff --git a/src/spikeinterface/extractors/iblstreamingrecording.py b/src/spikeinterface/extractors/iblstreamingrecording.py
index ca0f0a0335..35dccbef1e 100644
--- a/src/spikeinterface/extractors/iblstreamingrecording.py
+++ b/src/spikeinterface/extractors/iblstreamingrecording.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 
 import numpy as np
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface.core import BaseRecording, BaseRecordingSegment
 from spikeinterface.core.core_tools import define_function_from_class
@@ -165,7 +165,7 @@ def __init__(
 
         # set probe
         if not load_sync_channel:
-            probe = pi.read_spikeglx(meta_file)
+            probe = probeinterface.read_spikeglx(meta_file)
 
             if probe.shank_ids is not None:
                 self.set_probe(probe, in_place=True, group_mode="by_shank")
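The hunks above and below all make the same mechanical change: the `probeinterface as pi` alias is dropped in favour of fully qualified `probeinterface.*` calls. A minimal sketch of the resulting calling convention, reusing `generate_recording` and the probe helpers that already appear in the touched files (illustration only, not part of the patch):

```python
import numpy as np
import probeinterface

from spikeinterface.core import generate_recording

num_chan = 4
recording = generate_recording(num_channels=num_chan, durations=[2.0], sampling_frequency=30000.0)

# Build a simple linear probe, wire its contacts to the channels, and attach it,
# mirroring what test_ChannelSliceRecording does above.
probe = probeinterface.generate_linear_probe(num_elec=num_chan)
probe.set_device_channel_indices(np.arange(num_chan))
recording_with_probe = recording.set_probe(probe)
print(recording_with_probe.get_channel_locations())
```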
diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py
index cde1167835..b4f1e3f341 100644
--- a/src/spikeinterface/extractors/neoextractors/biocam.py
+++ b/src/spikeinterface/extractors/neoextractors/biocam.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface.core.core_tools import define_function_from_class
 
@@ -54,7 +54,7 @@ def __init__(
             probe_kwargs["mea_pitch"] = mea_pitch
         if electrode_width is not None:
             probe_kwargs["electrode_width"] = electrode_width
-        probe = pi.read_3brain(file_path, **probe_kwargs)
+        probe = probeinterface.read_3brain(file_path, **probe_kwargs)
         self.set_probe(probe, in_place=True)
         self.set_property("row", self.get_property("contact_vector")["row"])
         self.set_property("col", self.get_property("contact_vector")["col"])
diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py
index ea54f9f201..ca03aa7f85 100644
--- a/src/spikeinterface/extractors/neoextractors/maxwell.py
+++ b/src/spikeinterface/extractors/neoextractors/maxwell.py
@@ -1,7 +1,7 @@
 import numpy as np
 from pathlib import Path
 
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface import BaseEvent, BaseEventSegment
 from spikeinterface.core.core_tools import define_function_from_class
@@ -68,7 +68,7 @@ def __init__(
         well_name = self.stream_id
         # rec_name auto set by neo
         rec_name = self.neo_reader.rec_name
-        probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
+        probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
         self.set_probe(probe, in_place=True)
         self.set_property("electrode", self.get_property("contact_vector")["electrode"])
         self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name))
diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py
index 7dda9175f5..c0b820f65b 100644
--- a/src/spikeinterface/extractors/neoextractors/mearec.py
+++ b/src/spikeinterface/extractors/neoextractors/mearec.py
@@ -3,7 +3,7 @@
 import numpy as np
 
-import probeinterface as pi
+import probeinterface
 
 from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor
 
@@ -48,7 +48,7 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False):
 
         self.extra_requirements.append("mearec")
 
-        probe = pi.read_mearec(file_path)
+        probe = probeinterface.read_mearec(file_path)
         probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"])
         self.set_probe(probe, in_place=True)
         self.annotate(is_filtered=True)
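The Biocam, Maxwell and MeaRec extractors above share one pattern: read a `Probe` from the recording's own file, attach it with `set_probe(..., in_place=True)`, then republish fields of the resulting `contact_vector` property as channel properties. A hedged sketch of that pattern; the `attach_probe_from_file` helper is hypothetical and not spikeinterface API:

```python
import probeinterface


def attach_probe_from_file(recording, file_path, property_names=("electrode",)):
    """Hypothetical helper illustrating the shared probe-attachment pattern."""
    # Read the probe from the vendor file; read_3brain / read_mearec play the same role
    # for the other formats touched above.
    probe = probeinterface.read_maxwell(file_path)
    recording.set_probe(probe, in_place=True)

    # set_probe stores per-contact metadata in the "contact_vector" structured array;
    # selected fields are republished as plain channel properties.
    contact_vector = recording.get_property("contact_vector")
    for name in property_names:
        if name in contact_vector.dtype.names:
            recording.set_property(name, contact_vector[name])
    return recording
```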
- """ from pathlib import Path @@ -15,7 +13,7 @@ import numpy as np import warnings -import probeinterface as pi +import probeinterface from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor @@ -23,10 +21,10 @@ def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs): - # Temporary function until neo version 0.13.0 is released from packaging.version import Version from importlib.metadata import version as lib_version + # Temporary function until neo version 0.13.0 is released neo_version = lib_version("neo") # The possibility of ignoring timestamps errors is not present in neo <= 0.12.0 if Version(neo_version) <= Version("0.12.0"): @@ -178,7 +176,9 @@ def __init__( settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"] if Path(settings_file).is_file(): - probe = pi.read_openephys(settings_file=settings_file, stream_name=stream_name, raise_error=False) + probe = probeinterface.read_openephys( + settings_file=settings_file, stream_name=stream_name, raise_error=False + ) else: probe = None @@ -187,9 +187,16 @@ def __init__( self.set_probe(probe, in_place=True, group_mode="by_shank") else: self.set_probe(probe, in_place=True) - probe_name = probe.annotations["probe_name"] + + # this handles a breaking change in probeinterface after v0.2.18 + # in the new version, the Neuropixels model name is stored in the "model_name" annotation, + # rather than in the "probe_name" annotation + model_name = probe.annotations.get("model_name", None) + if model_name is None: + model_name = probe.annotations["probe_name"] + # load num_channels_per_adc depending on probe type - if "2.0" in probe_name: + if "2.0" in model_name: num_channels_per_adc = 16 num_cycles_in_adc = 16 total_channels = 384 @@ -203,7 +210,7 @@ def __init__( sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc) if self.get_num_channels() != total_channels: # need slice because not all channel are saved - chans = pi.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream) + chans = probeinterface.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream) # lets clip to 384 because this contains also the synchro channel chans = chans[chans < total_channels] sample_shifts = sample_shifts[chans] diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index db6ee9bd48..6a6901b62e 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -4,7 +4,7 @@ from pathlib import Path import neo -import probeinterface as pi +import probeinterface from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts @@ -60,7 +60,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ # Load probe geometry if available if "lf" in self.stream_id: meta_filename = meta_filename.replace(".lf", ".ap") - probe = pi.read_spikeglx(meta_filename) + probe = probeinterface.read_spikeglx(meta_filename) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") @@ -84,7 +84,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc) if self.get_num_channels() != total_channels: # need slice because not all channel are saved - chans = 
pi.get_saved_channel_indices_from_spikeglx_meta(meta_filename) + chans = probeinterface.get_saved_channel_indices_from_spikeglx_meta(meta_filename) # lets clip to 384 because this contains also the synchro channel chans = chans[chans < total_channels] sample_shifts = sample_shifts[chans] diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index d7ae09144f..ccb97e31b3 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -2,7 +2,7 @@ import numpy as np -from probeinterface import read_prb, write_prb +import probeinterface from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import write_binary_recording, define_function_from_class @@ -69,7 +69,7 @@ def __init__(self, file_path): ) # load probe file - probegroup = read_prb(params["probe"]) + probegroup = probeinterface.read_prb(params["probe"]) self.set_probegroup(probegroup, in_place=True) self._kwargs = {"file_path": str(Path(file_path).absolute())} self.extra_requirements.extend(["hybridizer", "pyyaml"]) @@ -119,7 +119,7 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * # write probe file probe_fn = (save_path / probe_name).absolute() probegroup = recording.get_probegroup() - write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels()) + probeinterface.write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels()) # create parameters file parameters = dict( diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 257c1d566a..64c6499767 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,10 +1,10 @@ import unittest import platform import subprocess +import os from packaging import version import pytest -import numpy as np from spikeinterface.core.testing import check_recordings_equal from spikeinterface import get_global_dataset_folder @@ -16,6 +16,7 @@ EventCommonTestSuite, ) +ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) local_folder = get_global_dataset_folder() / "ephy_testing_data" @@ -277,6 +278,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index 0525cdfc7a..43b35dfef9 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -2,7 +2,7 @@ import numpy as np from pathlib import Path -import probeinterface as pi +import probeinterface from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys from spikeinterface.preprocessing import depth_order, zscore @@ -29,7 +29,7 @@ def recording_and_shape(): num_cols = 2 num_rows = 64 - probe = pi.generate_multi_columns_probe(num_columns=num_cols, 
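The OpenEphys and SpikeGLX hunks above slice the per-channel Neuropixels sample shifts whenever fewer than 384 channels were saved, after clipping away the synchronization channel index. A small sketch of that slicing, using the ADC values from the "2.0" branch shown above and a hypothetical set of saved channel indices (illustration only, not part of the patch):

```python
import numpy as np

from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts

# Values taken from the "2.0" branch in the hunk above.
total_channels = 384
num_channels_per_adc = 16
num_cycles_in_adc = 16

# One shift per channel of the full probe.
sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc)

# Hypothetical set of saved channel indices, including the synchronization channel (index 384).
saved_channel_indices = np.array([0, 1, 2, 100, 200, 383, 384])
chans = saved_channel_indices[saved_channel_indices < total_channels]  # drop the sync channel
sample_shifts = sample_shifts[chans]  # keep only the shifts of the channels actually saved
print(sample_shifts.shape)  # (6,)
```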
diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py
index 0525cdfc7a..43b35dfef9 100644
--- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py
+++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py
@@ -2,7 +2,7 @@ import numpy as np
 from pathlib import Path
 
-import probeinterface as pi
+import probeinterface
 
 from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings
 from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys
 from spikeinterface.preprocessing import depth_order, zscore
@@ -29,7 +29,7 @@ def recording_and_shape():
     num_cols = 2
     num_rows = 64
-    probe = pi.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
+    probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
     probe.set_device_channel_indices(np.arange(num_cols * num_rows))
     recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000)
     recording.set_probe(probe, in_place=True)
diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py
index 75d64b0088..954f5ed7e8 100644
--- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py
+++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py
@@ -6,7 +6,7 @@
 
 from spikeinterface.core import generate_recording
 from spikeinterface.core.numpyextractors import NumpyRecording
-from spikeinterface.preprocessing import zero_channel_pad
+from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, phase_shift
 from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording
 
 if hasattr(pytest, "global_test_folder"):
@@ -39,7 +39,7 @@ def test_zero_padding_channel():
 @pytest.fixture
 def recording():
     num_channels = 4
-    num_samples = 10
+    num_samples = 10000
     rng = np.random.default_rng(seed=0)
     traces = rng.random(size=(num_samples, num_channels))
     traces_list = [traces]
@@ -258,5 +258,74 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
     assert np.allclose(padded_traces_end, expected_zeros)
 
 
+@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)])
+def test_trace_padded_recording_retrieve_only_start_padding(recording, padding_start, padding_end):
+    num_samples = recording.get_num_samples()
+    num_channels = recording.get_num_channels()
+
+    padded_recording = TracePaddedRecording(
+        parent_recording=recording,
+        padding_start=padding_start,
+        padding_end=padding_end,
+    )
+
+    # Retrieve the padding at the start and test it
+    padded_traces_start = padded_recording.get_traces(start_frame=0, end_frame=padding_start)
+    expected_traces = np.zeros((padding_start, num_channels))
+    assert np.allclose(padded_traces_start, expected_traces)
+
+
+@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)])
+def test_trace_padded_recording_retrieve_only_end_padding(recording, padding_start, padding_end):
+    num_samples = recording.get_num_samples()
+    num_channels = recording.get_num_channels()
+
+    padded_recording = TracePaddedRecording(
+        parent_recording=recording,
+        padding_start=padding_start,
+        padding_end=padding_end,
+    )
+
+    # Retrieve the padding at the end and test it
+    start_frame = padding_start + num_samples
+    end_frame = padding_start + num_samples + padding_end
+    padded_traces_end = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)
+    expected_traces = np.zeros((padding_end, num_channels))
+    assert np.allclose(padded_traces_end, expected_traces)
+
+
+@pytest.mark.parametrize("preprocessing", ["bandpass_filter", "phase_shift"])
+@pytest.mark.parametrize("padding_start, padding_end", [(5, 5), (0, 5), (5, 0), (0, 0)])
+def test_trace_padded_recording_retrieve_only_end_padding_with_preprocessing(
+    recording, padding_start, padding_end, preprocessing
+):
+    """This is a temporary test to check that this works when the recording is called out of bounds. It should be removed
+    when more general tests are added in that direction."""
+
+    num_samples = recording.get_num_samples()
+    num_channels = recording.get_num_channels()
+
+    if preprocessing == "bandpass_filter":
+        recording = bandpass_filter(recording, freq_min=300, freq_max=6000)
+    else:
+        sample_shift_size = 0.4
+        inter_sample_shift = np.arange(recording.get_num_channels()) * sample_shift_size
+        recording.set_property("inter_sample_shift", inter_sample_shift)
+        recording = phase_shift(recording)
+
+    padded_recording = TracePaddedRecording(
+        parent_recording=recording,
+        padding_start=padding_start,
+        padding_end=padding_end,
+    )
+
+    # Retrieve the padding at the end and test it
+    start_frame = padding_start + num_samples
+    end_frame = padding_start + num_samples + padding_end
+    padded_traces_end = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)
+    expected_traces = np.zeros((padding_end, num_channels))
+    assert np.allclose(padded_traces_end, expected_traces)
+
+
 if __name__ == "__main__":
     test_zero_padding_channel()
diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py
index ee6eb014aa..c1ed31f508 100644
--- a/src/spikeinterface/preprocessing/zero_channel_pad.py
+++ b/src/spikeinterface/preprocessing/zero_channel_pad.py
@@ -18,9 +18,11 @@ class TracePaddedRecording(BasePreprocessor):
     parent_recording_segment : BaseRecording
         The parent recording segment from which the traces are to be retrieved.
     padding_start : int, default: 0
-        The amount of padding to add to the left of the traces. It has to be non-negative
+        The amount of padding to add to the left of the traces. It has to be non-negative.
+        Note that this counts the number of samples, not the number of seconds.
     padding_end : int, default: 0
         The amount of padding to add to the right of the traces. It has to be non-negative
+        Note that this counts the number of samples, not the number of seconds.
     fill_value: float, default: 0
         The value to pad with
     """
@@ -88,13 +90,18 @@ def get_traces(self, start_frame, end_frame, channel_indices):
             raise ValueError(f"Unsupported channel_indices type: {type(channel_indices)} raise an issue on github ")
 
         # This avoids an extra memory allocation if we are within the confines of the old traces
-        if start_frame > self.padding_start and end_frame < self.num_samples_in_original_segment + self.padding_start:
+        end_of_original_traces = self.num_samples_in_original_segment + self.padding_start
+        if start_frame > self.padding_start and end_frame < end_of_original_traces:
             return self.get_original_traces_shifted(start_frame, end_frame, channel_indices)
 
-        # Else, we start with the full padded traces and allocate the original traces in the middle
+        # We start with the full padded traces and fill in the original traces if necessary
        output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype)
 
-        # After the padding, the original traces are placed in the middle until the end of the original traces
+        # If start frame is larger than the end of the original traces, we return the padded traces as they are
+        if start_frame >= end_of_original_traces and end_frame > end_of_original_traces:
+            return output_traces
+
+        # We add the original traces if end_frame is larger than the start of the original traces
         if end_frame >= self.padding_start:
             original_traces = self.get_original_traces_shifted(
                 start_frame=start_frame,
@@ -119,6 +126,7 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices):
         """
         original_start_frame = max(start_frame - self.padding_start, 0)
         original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment)
+
        original_traces = self.parent_recording_segment.get_traces(
             start_frame=original_start_frame,
             end_frame=original_end_frame,
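The new `TracePaddedRecording` branch above returns the pre-filled output as soon as the requested window lies entirely in the end padding, which is exactly what the added tests exercise. A usage sketch of that behaviour, built on `generate_recording` (illustration only, not part of the patch):

```python
import numpy as np

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording

recording = generate_recording(num_channels=4, durations=[1.0], sampling_frequency=30000.0)
num_samples = recording.get_num_samples()
padding_start, padding_end = 5, 5

padded = TracePaddedRecording(parent_recording=recording, padding_start=padding_start, padding_end=padding_end)

# A window that starts at or past the end of the original traces hits the new early-return
# branch and comes back filled with fill_value (0 by default), without touching the parent.
start_frame = padding_start + num_samples
end_frame = padding_start + num_samples + padding_end
traces = padded.get_traces(start_frame=start_frame, end_frame=end_frame)

assert traces.shape == (padding_end, recording.get_num_channels())
assert np.allclose(traces, 0.0)
```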