diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 9f3d6cc5cd..53fcd37f45 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -26,7 +26,7 @@ runs: - name: Force installation of latest dev from key-packages when running dev (not release) run: | source ${{ github.workspace }}/test_env/bin/activate - spikeinterface_is_dev_version=$(python -c "import importlib.metadata; version = importlib.metadata.version('spikeinterface'); print(version.endswith('dev0'))") + spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") if [ $spikeinterface_is_dev_version = "True" ]; then echo "Running spikeinterface dev version" pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 55d72f1be5..87c6b8acbc 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -399,10 +399,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save ---------- unit_ids: list or None Unit ids to retrieve waveforms for - mode: "average" | "median" | "std" | "percentile", default: "average" - The mode to compute the templates + operator: "average" | "median" | "std" | "percentile", default: "average" + The operator to compute the templates percentile: float, default: None - Percentile to use for mode="percentile" + Percentile to use for operator="percentile" save: bool, default True In case, the operator is not computed yet it can be saved to folder or zarr. @@ -437,6 +437,28 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save return np.array(templates) + def get_unit_template(self, unit_id, operator="average"): + """ + Return template for a single unit. + + Parameters + ---------- + unit_id: str | int + Unit id to retrieve waveforms for + operator: str + The operator to compute the templates + + Returns + ------- + template: np.array + The returned template (num_samples, num_channels) + """ + + templates = self.data[operator] + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + + return np.array(templates[unit_index, :, :]) + compute_templates = ComputeTemplates.function_factory() register_result_extension(ComputeTemplates) @@ -522,6 +544,55 @@ def _select_extension_data(self, unit_ids): return new_data + def get_templates(self, unit_ids=None, operator="average"): + """ + Return average templates for multiple units. + + Parameters + ---------- + unit_ids: list or None, default: None + Unit ids to retrieve waveforms for + operator: str + MUST be "average" (only one supported by fast_templates) + The argument exist to have the same signature as ComputeTemplates.get_templates + + Returns + ------- + templates: np.array + The returned templates (num_units, num_samples, num_channels) + """ + + assert ( + operator == "average" + ), f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" + templates = self.data["average"] + + if unit_ids is not None: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + templates = templates[unit_indices, :, :] + + return np.array(templates) + + def get_unit_template(self, unit_id): + """ + Return average template for a single unit. + + Parameters + ---------- + unit_id: str | int + Unit id to retrieve waveforms for + + Returns + ------- + template: np.array + The returned template (num_samples, num_channels) + """ + + templates = self.data["average"] + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + + return np.array(templates[unit_index, :, :]) + compute_fast_templates = ComputeFastTemplates.function_factory() register_result_extension(ComputeFastTemplates) diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 4d6121cc9b..d8826618a9 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -56,8 +56,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike ), "`start_frame` should be smaller than the sortings' total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( - "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " - "with the `spikeinterface.curation.remove_excess_spikes()` function" + "The sorting object has spikes whose times go beyond the recording duration." + "This could indicate a bug in the sorter. " + "To remove those spikes, you can use `spikeinterface.curation.remove_excess_spikes()`." ) else: # Pull df end_frame from spikes diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 30f74e584b..52cf052ed2 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -180,8 +180,10 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) - templates_ext = sorting_analyzer.get_extension("templates") - templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'" + templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates") + assert ( + templates_ext is not None + ), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 287522a870..3234456e0f 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -10,15 +10,6 @@ from spikeinterface.core.core_tools import define_function_from_class -def import_lazily(): - "Makes annotations / typing available lazily" - global NWBFile, ElectricalSeries, Units, NWBHDF5IO - from pynwb import NWBFile - from pynwb.ecephys import ElectricalSeries - from pynwb.misc import Units - from pynwb import NWBHDF5IO - - def read_file_from_backend( *, file_path: str | Path | None, @@ -111,7 +102,7 @@ def read_nwbfile( cache: bool = False, stream_cache_path: str | Path | None = None, storage_options: dict | None = None, -) -> NWBFile: +) -> "NWBFile": """ Read an NWB file and return the NWBFile object. @@ -176,8 +167,8 @@ def read_nwbfile( def _retrieve_electrical_series_pynwb( - nwbfile: NWBFile, electrical_series_path: Optional[str] = None -) -> ElectricalSeries: + nwbfile: "NWBFile", electrical_series_path: Optional[str] = None +) -> "ElectricalSeries": """ Get an ElectricalSeries object from an NWBFile. @@ -230,7 +221,7 @@ def _retrieve_electrical_series_pynwb( return electrical_series -def _retrieve_unit_table_pynwb(nwbfile: NWBFile, unit_table_path: Optional[str] = None) -> Units: +def _retrieve_unit_table_pynwb(nwbfile: "NWBFile", unit_table_path: Optional[str] = None) -> "Units": """ Get an Units object from an NWBFile. Units tables can be either the main unit table (nwbfile.units) or in the processing module. diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 8a2a677d79..be7d8dc9ed 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -136,6 +136,21 @@ class IntanRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +class IntanRecordingTestMultipleFilesFormat(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = IntanRecordingExtractor + downloads = ["intan"] + entities = [ + ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "RHD2000 amplifier channel"}), + ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "RHD2000 auxiliary input channel"}), + ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "USB board ADC input channel"}), + ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "USB board digital input channel"}), + ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "RHD2000 amplifier channel"}), + ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "RHD2000 auxiliary input channel"}), + ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "USB board ADC input channel"}), + ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "USB board digital input channel"}), + ] + + class NeuroScopeRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuroScopeRecordingExtractor downloads = ["neuroscope"] diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 22aee972b9..3299164c07 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -137,7 +137,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - bins = np.arange(0, window_size + bin_size, bin_size) # * 1e3 / fs + bins = np.arange(0, window_size + bin_size, bin_size, dtype=np.int64) spikes = sorting.to_spike_vector(concatenated=False) ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)