diff --git a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py index 682393f8..3641539f 100644 --- a/spikeextractors/extractors/cedextractors/cedrecordingextractor.py +++ b/spikeextractors/extractors/cedextractors/cedrecordingextractor.py @@ -5,6 +5,7 @@ import numpy as np from pathlib import Path from typing import Union +from copy import deepcopy try: from sonpy import lib as sp @@ -27,7 +28,7 @@ class CEDRecordingExtractor(RecordingExtractor): ---------- file_path: str Path to the .smrx file to be extracted - smrx_ch_inds: list of int + smrx_channel_ids: list of int List with indexes of valid smrx channels. Does not match necessarily with extractor id. """ @@ -38,10 +39,11 @@ class CEDRecordingExtractor(RecordingExtractor): mode = 'file' installation_mesg = "To use the CED extractor, install sonpy: \n\n pip install sonpy\n\n" # error message when not installed - def __init__(self, file_path: PathType, smrx_ch_inds: list): + def __init__(self, file_path: PathType, smrx_channel_ids: list): assert HAVE_SONPY, self.installation_mesg file_path = Path(file_path) assert file_path.is_file() and file_path.suffix == '.smrx', 'file_path must lead to a .smrx file!' + assert len(smrx_channel_ids) > 0, "'smrx_channel_ids' cannot be an empty list!" super().__init__() @@ -55,7 +57,8 @@ def __init__(self, file_path: PathType, smrx_ch_inds: list): # get channel info / set channel gains self._channelid_to_smrxind = dict() self._channel_smrxinfo = dict() - for i, ind in enumerate(smrx_ch_inds): + self._channel_names = [] + for i, ind in enumerate(smrx_channel_ids): if self._recording_file.ChannelType(ind) == sp.DataType.Off: raise ValueError(f'Channel {ind} is type Off and cannot be used') self._channelid_to_smrxind[i] = ind @@ -70,9 +73,19 @@ def __init__(self, file_path: PathType, smrx_ch_inds: list): channel_ids=[i], gains=gains ) + self._channel_names.append(self._channel_smrxinfo[i]['title']) + + rate0 = self._channel_smrxinfo[0]['rate'] + for chan, info in self._channel_smrxinfo.items(): + assert info['rate'] == rate0, "Inconsistency between 'sampling_frequency' of different channels. The " \ + "extractor only supports channels with the same 'rate'" self._kwargs = {'file_path': str(Path(file_path).absolute()), - 'smrx_ch_inds': smrx_ch_inds} + 'smrx_channel_ids': smrx_channel_ids} + + @property + def channel_names(self): + return deepcopy(self._channel_names) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):