From 1d26a552679ab764172d760893b5f9e325d3e814 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 8 Dec 2020 18:23:09 +0100 Subject: [PATCH 1/3] CEDREcordingExtractor: assert that the list of channels is not empty --- .../extractors/cedextractors/cedrecordingextractor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py index 682393f8..70666cf9 100644 --- a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py +++ b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py @@ -27,7 +27,7 @@ class CEDRecordingExtractor(RecordingExtractor): ---------- file_path: str Path to the .smrx file to be extracted - smrx_ch_inds: list of int + smrx_channel_ids: list of int List with indexes of valid smrx channels. Does not match necessarily with extractor id. """ @@ -38,10 +38,11 @@ class CEDRecordingExtractor(RecordingExtractor): mode = 'file' installation_mesg = "To use the CED extractor, install sonpy: \n\n pip install sonpy\n\n" # error message when not installed - def __init__(self, file_path: PathType, smrx_ch_inds: list): + def __init__(self, file_path: PathType, smrx_channel_ids: list): assert HAVE_SONPY, self.installation_mesg file_path = Path(file_path) assert file_path.is_file() and file_path.suffix == '.smrx', 'file_path must lead to a .smrx file!' + assert len(smrx_channel_ids) > 0, "'smrx_channel_ids' cannot be an empty list!" super().__init__() @@ -55,7 +56,7 @@ def __init__(self, file_path: PathType, smrx_ch_inds: list): # get channel info / set channel gains self._channelid_to_smrxind = dict() self._channel_smrxinfo = dict() - for i, ind in enumerate(smrx_ch_inds): + for i, ind in enumerate(smrx_channel_ids): if self._recording_file.ChannelType(ind) == sp.DataType.Off: raise ValueError(f'Channel {ind} is type Off and cannot be used') self._channelid_to_smrxind[i] = ind @@ -72,7 +73,7 @@ def __init__(self, file_path: PathType, smrx_ch_inds: list): ) self._kwargs = {'file_path': str(Path(file_path).absolute()), - 'smrx_ch_inds': smrx_ch_inds} + 'smrx_channel_ids': smrx_channel_ids} @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): From 44e5c1cc98017ad245b77318f4c0acd991753a23 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 8 Dec 2020 18:35:24 +0100 Subject: [PATCH 2/3] Added assertion on same sampling rate between channels --- .../extractors/cedextractors/cedrecordingextractor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py index 70666cf9..19764038 100644 --- a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py +++ b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py @@ -72,6 +72,11 @@ def __init__(self, file_path: PathType, smrx_channel_ids: list): gains=gains ) + rate0 = self._channel_smrxinfo[0]['rate'] + for chan, info in self._channel_smrxinfo.items(): + assert info['rate'] == rate0, "Inconsistency between 'sampling_frequency' of different channels. The " \ + "extractor only supports channels with the same 'rate'" + self._kwargs = {'file_path': str(Path(file_path).absolute()), 'smrx_channel_ids': smrx_channel_ids} From 9738948a177e1ed74c5e4a9015a6cb20a7a8e5ea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 8 Dec 2020 18:42:04 +0100 Subject: [PATCH 3/3] Added 'channel_names' property --- .../extractors/cedextractors/cedrecordingextractor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py index 19764038..3641539f 100644 --- a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py +++ b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py @@ -5,6 +5,7 @@ import numpy as np from pathlib import Path from typing import Union +from copy import deepcopy try: from sonpy import lib as sp @@ -56,6 +57,7 @@ def __init__(self, file_path: PathType, smrx_channel_ids: list): # get channel info / set channel gains self._channelid_to_smrxind = dict() self._channel_smrxinfo = dict() + self._channel_names = [] for i, ind in enumerate(smrx_channel_ids): if self._recording_file.ChannelType(ind) == sp.DataType.Off: raise ValueError(f'Channel {ind} is type Off and cannot be used') @@ -71,6 +73,7 @@ def __init__(self, file_path: PathType, smrx_channel_ids: list): channel_ids=[i], gains=gains ) + self._channel_names.append(self._channel_smrxinfo[i]['title']) rate0 = self._channel_smrxinfo[0]['rate'] for chan, info in self._channel_smrxinfo.items(): @@ -80,6 +83,10 @@ def __init__(self, file_path: PathType, smrx_channel_ids: list): self._kwargs = {'file_path': str(Path(file_path).absolute()), 'smrx_channel_ids': smrx_channel_ids} + @property + def channel_names(self): + return deepcopy(self._channel_names) + @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): '''This function extracts and returns a trace from the recorded data from the