Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… into circus2_improvements
  • Loading branch information
yger committed Mar 28, 2024
2 parents 62ceeb3 + 9704a8f commit 2510af4
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/actions/build-test-environment/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ runs:
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
source ${{ github.workspace }}/test_env/bin/activate
spikeinterface_is_dev_version=$(python -c "import importlib.metadata; version = importlib.metadata.version('spikeinterface'); print(version.endswith('dev0'))")
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
Expand Down
77 changes: 74 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
----------
unit_ids: list or None
Unit ids to retrieve waveforms for
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the templates
operator: "average" | "median" | "std" | "percentile", default: "average"
The operator to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
Percentile to use for operator="percentile"
save: bool, default True
In case, the operator is not computed yet it can be saved to folder or zarr.
Expand Down Expand Up @@ -437,6 +437,28 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save

return np.array(templates)

def get_unit_template(self, unit_id, operator="average"):
"""
Return template for a single unit.
Parameters
----------
unit_id: str | int
Unit id to retrieve waveforms for
operator: str
The operator to compute the templates
Returns
-------
template: np.array
The returned template (num_samples, num_channels)
"""

templates = self.data[operator]
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)

return np.array(templates[unit_index, :, :])


# Module-level convenience function and extension registration for
# ComputeTemplates (function_factory/register_result_extension presumably
# hook the extension into the SortingAnalyzer machinery — defined elsewhere).
compute_templates = ComputeTemplates.function_factory()
register_result_extension(ComputeTemplates)
Expand Down Expand Up @@ -522,6 +544,55 @@ def _select_extension_data(self, unit_ids):

return new_data

def get_templates(self, unit_ids=None, operator="average"):
"""
Return average templates for multiple units.
Parameters
----------
unit_ids: list or None, default: None
Unit ids to retrieve waveforms for
operator: str
MUST be "average" (only one supported by fast_templates)
The argument exist to have the same signature as ComputeTemplates.get_templates
Returns
-------
templates: np.array
The returned templates (num_units, num_samples, num_channels)
"""

assert (
operator == "average"
), f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}"
templates = self.data["average"]

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
templates = templates[unit_indices, :, :]

return np.array(templates)

def get_unit_template(self, unit_id):
"""
Return average template for a single unit.
Parameters
----------
unit_id: str | int
Unit id to retrieve waveforms for
Returns
-------
template: np.array
The returned template (num_samples, num_channels)
"""

templates = self.data["average"]
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)

return np.array(templates[unit_index, :, :])


# Module-level convenience function and extension registration for
# ComputeFastTemplates, mirroring the ComputeTemplates registration above.
compute_fast_templates = ComputeFastTemplates.function_factory()
register_result_extension(ComputeFastTemplates)
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
), "`start_frame` should be smaller than the sortings' total number of samples."
if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
raise ValueError(
"The sorting object has spikes exceeding the recording duration. You have to remove those spikes "
"with the `spikeinterface.curation.remove_excess_spikes()` function"
"The sorting object has spikes whose times go beyond the recording duration."
"This could indicate a bug in the sorter. "
"To remove those spikes, you can use `spikeinterface.curation.remove_excess_spikes()`."
)
else:
# Pull df end_frame from spikes
Expand Down
6 changes: 4 additions & 2 deletions src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ def export_to_phy(

# export templates/templates_ind/similar_templates
# shape (num_units, num_samples, max_num_channels)
templates_ext = sorting_analyzer.get_extension("templates")
templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'"
templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates")
assert (
templates_ext is not None
), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'"
max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values())
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
num_samples = dense_templates.shape[1]
Expand Down
17 changes: 4 additions & 13 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
from spikeinterface.core.core_tools import define_function_from_class


def import_lazily():
"Makes annotations / typing available lazily"
global NWBFile, ElectricalSeries, Units, NWBHDF5IO
from pynwb import NWBFile
from pynwb.ecephys import ElectricalSeries
from pynwb.misc import Units
from pynwb import NWBHDF5IO


def read_file_from_backend(
*,
file_path: str | Path | None,
Expand Down Expand Up @@ -111,7 +102,7 @@ def read_nwbfile(
cache: bool = False,
stream_cache_path: str | Path | None = None,
storage_options: dict | None = None,
) -> NWBFile:
) -> "NWBFile":
"""
Read an NWB file and return the NWBFile object.
Expand Down Expand Up @@ -176,8 +167,8 @@ def read_nwbfile(


def _retrieve_electrical_series_pynwb(
nwbfile: NWBFile, electrical_series_path: Optional[str] = None
) -> ElectricalSeries:
nwbfile: "NWBFile", electrical_series_path: Optional[str] = None
) -> "ElectricalSeries":
"""
Get an ElectricalSeries object from an NWBFile.
Expand Down Expand Up @@ -230,7 +221,7 @@ def _retrieve_electrical_series_pynwb(
return electrical_series


def _retrieve_unit_table_pynwb(nwbfile: NWBFile, unit_table_path: Optional[str] = None) -> Units:
def _retrieve_unit_table_pynwb(nwbfile: "NWBFile", unit_table_path: Optional[str] = None) -> "Units":
"""
Get an Units object from an NWBFile.
Units tables can be either the main unit table (nwbfile.units) or in the processing module.
Expand Down
15 changes: 15 additions & 0 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,21 @@ class IntanRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
]


class IntanRecordingTestMultipleFilesFormat(RecordingCommonTestSuite, unittest.TestCase):
    """Common recording tests for Intan multi-file datasets (fpc/fps variants), one per stream."""

    # Extractor class exercised by the shared test suite.
    ExtractorClass = IntanRecordingExtractor
    # Test-data set name — presumably fetched by the common suite before running; confirm in RecordingCommonTestSuite.
    downloads = ["intan"]
    # (relative file path, extractor kwargs) pairs: each tuple opens the same
    # recording with a different stream_name, covering all four stream types
    # for each of the two datasets.
    entities = [
        ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "RHD2000 amplifier channel"}),
        ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "RHD2000 auxiliary input channel"}),
        ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "USB board ADC input channel"}),
        ("intan/intan_fpc_test_231117_052630/info.rhd", {"stream_name": "USB board digital input channel"}),
        ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "RHD2000 amplifier channel"}),
        ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "RHD2000 auxiliary input channel"}),
        ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "USB board ADC input channel"}),
        ("intan/intan_fps_test_231117_052500/info.rhd", {"stream_name": "USB board digital input channel"}),
    ]


class NeuroScopeRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = NeuroScopeRecordingExtractor
downloads = ["neuroscope"]
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/isi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
bin_size = int(round(fs * bin_ms * 1e-3))
window_size -= window_size % bin_size

bins = np.arange(0, window_size + bin_size, bin_size) # * 1e3 / fs
bins = np.arange(0, window_size + bin_size, bin_size, dtype=np.int64)
spikes = sorting.to_spike_vector(concatenated=False)

ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)
Expand Down

0 comments on commit 2510af4

Please sign in to comment.