Skip to content

Commit

Permalink
Merge branch 'main' into factoring_omp
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Sep 13, 2023
2 parents dda7803 + b240298 commit e4b99cb
Show file tree
Hide file tree
Showing 7 changed files with 519 additions and 34 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ NEO-based
.. autofunction:: read_mcsraw
.. autofunction:: read_neuralynx
.. autofunction:: read_neuralynx_sorting
.. autofunction:: read_neuroexplorer
.. autofunction:: read_neuroscope
.. autofunction:: read_nix
.. autofunction:: read_openephys
Expand All @@ -102,6 +103,7 @@ NEO-based
.. autofunction:: read_spikeglx
.. autofunction:: read_tdt


Non-NEO-based
~~~~~~~~~~~~~
.. automodule:: spikeinterface.extractors
Expand Down
17 changes: 9 additions & 8 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math

import warnings
import numpy as np
from typing import Union, Optional, List, Literal

Expand Down Expand Up @@ -1037,13 +1037,14 @@ def __init__(
parent_recording: Union[BaseRecording, None] = None,
num_samples: Optional[List[int]] = None,
upsample_vector: Union[List[int], None] = None,
check_borbers: bool = True,
check_borders: bool = False,
) -> None:
templates = np.asarray(templates)
if check_borbers:
# TODO: this should be external to this class. It is not the responsability of this class to check the templates
if check_borders:
self._check_templates(templates)
# lets test this only once so force check_borbers=false for kwargs
check_borbers = False
# lets test this only once so force check_borders=False for kwargs
check_borders = False
self.templates = templates

channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2]))
Expand Down Expand Up @@ -1131,7 +1132,7 @@ def __init__(
"nbefore": nbefore,
"amplitude_factor": amplitude_factor,
"upsample_vector": upsample_vector,
"check_borbers": check_borbers,
"check_borders": check_borders,
}
if parent_recording is None:
self._kwargs["num_samples"] = num_samples
Expand All @@ -1144,8 +1145,8 @@ def _check_templates(templates: np.ndarray):
threshold = 0.01 * max_value

if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold:
raise Exception(
"Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger."
warnings.warn(
"Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger."
)


Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/neoextractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
read_neuroscope_sorting,
read_neuroscope,
)
from .neuroexplorer import NeuroExplorerRecordingExtractor, read_neuroexplorer
from .nix import NixRecordingExtractor, read_nix
from .openephys import (
OpenEphysLegacyRecordingExtractor,
Expand Down Expand Up @@ -62,6 +63,7 @@
SpikeGadgetsRecordingExtractor,
SpikeGLXRecordingExtractor,
TdtRecordingExtractor,
NeuroExplorerRecordingExtractor,
]

neo_sorting_extractors_list = [
Expand Down
66 changes: 66 additions & 0 deletions src/spikeinterface/extractors/neoextractors/neuroexplorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from pathlib import Path

from spikeinterface.core.core_tools import define_function_from_class

from .neobaseextractor import NeoBaseRecordingExtractor


class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor):
"""
Class for reading NEX (NeuroExplorer data format) files.
Based on :py:class:`neo.rawio.NeuroExplorerRawIO`
Importantly, at the moment, this recorder only extracts one channel of the recording.
This is because the NeuroExplorerRawIO class does not support multi-channel recordings
as in the NeuroExplorer format they might have different sampling rates.
Consider extracting all the channels and then concatenating them with the aggregate_channels function.
>>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor
>>> from spikeinterface.core import aggregate_channels
>>>
>>> file_path="/the/path/to/your/nex/file.nex"
>>>
>>> streams = NeuroExplorerRecordingExtractor.get_streams(file_path=file_path)
>>> stream_names = streams[0]
>>>
>>> your_signal_stream_names = "Here goes the logic to filter from stream names the ones that you know have the same sampling rate and you want to aggregate"
>>>
>>> recording_list = [NeuroExplorerRecordingExtractor(file_path=file_path, stream_name=stream_name) for stream_name in your_signal_stream_names]
>>> recording = aggregate_channels(recording_list)
Parameters
----------
file_path: str
The file path to load the recordings from.
stream_id: str, optional
If there are several streams, specify the stream id you want to load.
For this neo reader streams are defined by their sampling frequency.
stream_name: str, optional
If there are several streams, specify the stream name you want to load.
all_annotations: bool, default: False
Load exhaustively all annotations from neo.
"""

mode = "file"
NeoRawIOClass = "NeuroExplorerRawIO"
name = "neuroexplorer"

def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False):
neo_kwargs = {"filename": str(file_path)}
NeoBaseRecordingExtractor.__init__(
self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs
)
self._kwargs.update({"file_path": str(Path(file_path).absolute())})
self.extra_requirements.append("neo[edf]")

@classmethod
def map_to_neo_kwargs(cls, file_path):
neo_kwargs = {"filename": str(file_path)}
return neo_kwargs


read_neuroexplorer = define_function_from_class(source_class=NeuroExplorerRecordingExtractor, name="read_neuroexplorer")
11 changes: 11 additions & 0 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ class NeuroScopeRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
]


class NeuroExplorerRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = NeuroExplorerRecordingExtractor
downloads = ["neuroexplorer"]
entities = [
("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel01"}),
("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel02"}),
("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel01"}),
("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel02"}),
]


class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase):
ExtractorClass = NeuroScopeSortingExtractor
downloads = ["neuroscope"]
Expand Down
Loading

0 comments on commit e4b99cb

Please sign in to comment.