diff --git a/spikeextractors/extractors/maxoneextractors/maxoneextractors.py b/spikeextractors/extractors/maxoneextractors/maxoneextractors.py index 117f6e92..26a247cd 100644 --- a/spikeextractors/extractors/maxoneextractors/maxoneextractors.py +++ b/spikeextractors/extractors/maxoneextractors/maxoneextractors.py @@ -103,9 +103,10 @@ def get_sampling_frequency(self): def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if np.array(channel_ids).size > 1: if np.any(np.diff(channel_ids) < 0): - sorted_idx = np.argsort(channel_ids) - recordings = self._signals[np.sort(channel_ids), start_frame:end_frame] - return (recordings[sorted_idx] * self._lsb).astype('float') + sorted_channel_ids = np.sort(channel_ids) + sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) + recordings = (self._signals[sorted_channel_ids, start_frame:end_frame] * self._lsb).astype('float32') + return recordings[sorted_idx] else: return (self._signals[np.array(channel_ids), start_frame:end_frame] * self._lsb).astype('float32') else: diff --git a/spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py b/spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py index 5a97d7d5..bdf624f8 100644 --- a/spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py +++ b/spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py @@ -83,8 +83,9 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if np.array(channel_idxs).size > 1: if np.any(np.diff(channel_idxs) < 0): - sorted_idx = np.argsort(channel_idxs) - recordings = stream.get('ChannelData')[np.sort(channel_idxs), start_frame:end_frame] + sorted_channel_ids = np.sort(channel_idxs) + sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_idxs]) + recordings = stream.get('ChannelData')[sorted_channel_ids, start_frame:end_frame] return recordings[sorted_idx] * conv else: return stream.get('ChannelData')[np.sort(channel_idxs), start_frame:end_frame] * conv diff --git a/spikeextractors/extractors/mea1kextractors/mea1kextractors.py b/spikeextractors/extractors/mea1kextractors/mea1kextractors.py index fc46208b..d57b5997 100644 --- a/spikeextractors/extractors/mea1kextractors/mea1kextractors.py +++ b/spikeextractors/extractors/mea1kextractors/mea1kextractors.py @@ -184,9 +184,10 @@ def get_sampling_frequency(self): def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if np.array(channel_ids).size > 1: if np.any(np.diff(channel_ids) < 0): - sorted_idx = np.argsort(channel_ids) - recordings = self._signals[np.sort(channel_ids), start_frame:end_frame] - return recordings[sorted_idx].astype('float') + sorted_channel_ids = np.sort(channel_ids) + sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) + recordings = (self._signals[sorted_channel_ids, start_frame:end_frame] * self._lsb).astype('float32') + return recordings[sorted_idx] else: return (self._signals[np.array(channel_ids), start_frame:end_frame] * self._lsb).astype('float32') else: diff --git a/spikeextractors/extractors/mearecextractors/mearecextractors.py b/spikeextractors/extractors/mearecextractors/mearecextractors.py index 64713237..5cfa70a6 100644 --- a/spikeextractors/extractors/mearecextractors/mearecextractors.py +++ b/spikeextractors/extractors/mearecextractors/mearecextractors.py @@ -77,13 +77,14 @@ def get_sampling_frequency(self): @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if np.any(np.diff(channel_ids) < 0): - sorted_idx = np.argsort(channel_ids) - recordings = self._recordings[start_frame:end_frame, np.sort(channel_ids)] - return np.array(recordings[sorted_idx]).transpose() + sorted_channel_ids = np.sort(channel_ids) + sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) + recordings = self._recordings[start_frame:end_frame, sorted_channel_ids] + return np.array(recordings[:, sorted_idx]).T else: if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1): channel_ids = slice(channel_ids[0], channel_ids[0] + len(channel_ids)) - return np.array(self._recordings[start_frame:end_frame, channel_ids]).transpose() + return np.array(self._recordings[start_frame:end_frame, channel_ids]).T @staticmethod def write_recording(recording, save_path, check_suffix=True): diff --git a/spikeextractors/extractors/nwbextractors/nwbextractors.py b/spikeextractors/extractors/nwbextractors/nwbextractors.py index 095d03c3..be2666d2 100644 --- a/spikeextractors/extractors/nwbextractors/nwbextractors.py +++ b/spikeextractors/extractors/nwbextractors/nwbextractors.py @@ -287,8 +287,9 @@ def get_traces(self, channel_ids=None, start_frame=None, end_frame=None): if np.array(channel_ids).size > 1 and np.any(np.diff(channel_ids) < 0): # get around h5py constraint that it does not allow datasets # to be indexed out of order - sorted_idx = np.argsort(channel_inds) - recordings = es.data[start_frame:end_frame, np.sort(channel_inds)].T + sorted_channel_ids = np.sort(channel_ids) + sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) + recordings = es.data[start_frame:end_frame, sorted_channel_ids].T traces = recordings[sorted_idx, :] else: traces = es.data[start_frame:end_frame, channel_inds].T