diff --git a/neo/rawio/openephysbinaryrawio.py b/neo/rawio/openephysbinaryrawio.py
index 74534ac7d..a59137696 100644
--- a/neo/rawio/openephysbinaryrawio.py
+++ b/neo/rawio/openephysbinaryrawio.py
@@ -98,13 +98,19 @@ def _parse_header(self):
                 self._sig_streams[block_index][seg_index] = {}
                 self._evt_streams[block_index][seg_index] = {}
                 for stream_index, stream_name in enumerate(sig_stream_names):
-                    d = all_streams[block_index][seg_index]['continuous'][stream_name]
-                    d['stream_name'] = stream_name
-                    self._sig_streams[block_index][seg_index][stream_index] = d
+                    info_cnt = all_streams[block_index][seg_index]['continuous'][stream_name]
+                    info_cnt['stream_name'] = stream_name
+                    self._sig_streams[block_index][seg_index][stream_index] = info_cnt
+
+                    # check for SYNC channel for Neuropixels streams
+                    has_sync_trace = any(["SYNC" in ch["channel_name"]
+                                          for ch in info_cnt["channels"]])
+                    self._sig_streams[block_index][seg_index][stream_index]['has_sync_trace'] \
+                        = has_sync_trace
                 for i, stream_name in enumerate(event_stream_names):
-                    d = all_streams[block_index][seg_index]['events'][stream_name]
-                    d['stream_name'] = stream_name
-                    self._evt_streams[block_index][seg_index][i] = d
+                    info_evt = all_streams[block_index][seg_index]['events'][stream_name]
+                    info_evt['stream_name'] = stream_name
+                    self._evt_streams[block_index][seg_index][i] = info_evt
 
         # signals zone
         # create signals channel map: several channels per stream
@@ -112,9 +118,9 @@ def _parse_header(self):
         for stream_index, stream_name in enumerate(sig_stream_names):
             # stream_index is the index in the vector of stream names
             stream_id = str(stream_index)
-            d = self._sig_streams[0][0][stream_index]
+            info = self._sig_streams[0][0][stream_index]
             new_channels = []
-            for chan_info in d['channels']:
+            for chan_info in info['channels']:
                 chan_id = chan_info['channel_name']
                 if "SYNC" in chan_id and not self.load_sync_channel:
                     continue
@@ -124,7 +130,7 @@ def _parse_header(self):
                 else:
                     units = chan_info["units"]
                 new_channels.append((chan_info['channel_name'],
-                                     chan_id, float(d['sample_rate']), d['dtype'], units,
+                                     chan_id, float(info['sample_rate']), info['dtype'], units,
                                      chan_info['bit_volts'], 0., stream_id))
             signal_channels.extend(new_channels)
         signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
@@ -138,56 +144,49 @@ def _parse_header(self):
         # create memmap for signals
         for block_index in range(nb_block):
             for seg_index in range(nb_segment_per_block[block_index]):
-                for stream_index, d in self._sig_streams[block_index][seg_index].items():
-                    num_channels = len(d['channels'])
-                    memmap_sigs = np.memmap(d['raw_filename'], d['dtype'],
+                for stream_index, info in self._sig_streams[block_index][seg_index].items():
+                    num_channels = len(info['channels'])
+                    memmap_sigs = np.memmap(info['raw_filename'], info['dtype'],
                                             order='C', mode='r').reshape(-1, num_channels)
-                    channel_names = [ch["channel_name"] for ch in d["channels"]]
-                    # if there is a sync channel and it should not be loaded,
-                    # find the right channel index and slice the memmap
-                    if any(["SYNC" in ch for ch in channel_names]) and \
-                            not self.load_sync_channel:
-                        sync_channel_name = [ch for ch in channel_names if "SYNC" in ch][0]
-                        sync_channel_index = channel_names.index(sync_channel_name)
-
-                        # only sync channel in last position is supported to keep memmap
-                        if sync_channel_index == num_channels - 1:
-                            memmap_sigs = memmap_sigs[:, :-1]
-                        else:
-                            raise NotImplementedError("SYNC channel removal is only supported "
-                                                      "when the sync channel is in the last "
-                                                      "position")
-                    d['memmap'] = memmap_sigs
+                    has_sync_trace = \
+                        self._sig_streams[block_index][seg_index][stream_index]['has_sync_trace']
+
+                    # check sync channel validity (only for AP and LF)
+                    if not has_sync_trace and self.load_sync_channel \
+                            and "NI-DAQ" not in info["stream_name"]:
+                        raise ValueError("SYNC channel is not present in the recording. "
+                                         "Set load_sync_channel to False")
+                    info['memmap'] = memmap_sigs
 
         # events zone
         # channel map: one channel one stream
         event_channels = []
         for stream_ind, stream_name in enumerate(event_stream_names):
-            d = self._evt_streams[0][0][stream_ind]
-            if 'states' in d:
+            info = self._evt_streams[0][0][stream_ind]
+            if 'states' in info:
                 evt_channel_type = "epoch"
             else:
                 evt_channel_type = "event"
-            event_channels.append((d['channel_name'], d['channel_name'], evt_channel_type))
+            event_channels.append((info['channel_name'], info['channel_name'], evt_channel_type))
         event_channels = np.array(event_channels, dtype=_event_channel_dtype)
 
         # create memmap for events
         for block_index in range(nb_block):
            for seg_index in range(nb_segment_per_block[block_index]):
-                for stream_index, d in self._evt_streams[block_index][seg_index].items():
+                for stream_index, info in self._evt_streams[block_index][seg_index].items():
                     for name in _possible_event_stream_names:
-                        if name + '_npy' in d:
-                            data = np.load(d[name + '_npy'], mmap_mode='r')
-                            d[name] = data
+                        if name + '_npy' in info:
+                            data = np.load(info[name + '_npy'], mmap_mode='r')
+                            info[name] = data
 
                     # check that events have timestamps
-                    assert 'timestamps' in d, "Event stream does not have timestamps!"
+                    assert 'timestamps' in info, "Event stream does not have timestamps!"
                     # Updates for OpenEphys v0.6:
                     # In new versions (>=0.6) timestamps.npy is called sample_numbers.npy
                     # The timestamps are already in seconds, so event times don't require scaling
                     # see https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/Binary-format.html#events
-                    if 'sample_numbers' in d:
+                    if 'sample_numbers' in info:
                         self._use_direct_evt_timestamps = True
                     else:
                         self._use_direct_evt_timestamps = False
@@ -196,26 +195,28 @@ def _parse_header(self):
                     # of event (ttl, text, binary)
                     # and this is transformed into unicode
                     # all these data are put in event array annotations
-                    if 'text' in d:
+                    if 'text' in info:
                         # text case
-                        d['labels'] = d['text'].astype('U')
-                    elif 'metadata' in d:
+                        info['labels'] = info['text'].astype('U')
+                    elif 'metadata' in info:
                         # binary case
-                        d['labels'] = d['channels'].astype('U')
-                    elif 'channels' in d:
+                        info['labels'] = info['channels'].astype('U')
+                    elif 'channels' in info:
                         # ttl case use channels
-                        d['labels'] = d['channels'].astype('U')
-                    elif 'states' in d:
+                        info['labels'] = info['channels'].astype('U')
+                    elif 'states' in info:
                         # ttl case use states
-                        d['labels'] = d['states'].astype('U')
+                        info['labels'] = info['states'].astype('U')
                     else:
-                        raise ValueError(f'There is no possible labels for this event: {stream_name}')
+                        raise ValueError(
+                            "There are no possible labels for this event!"
+                        )
 
                     # If available, use 'states' to compute event duration
-                    if 'states' in d and d["states"].size:
-                        states = d["states"]
-                        timestamps = d["timestamps"]
-                        labels = d["labels"]
+                    if 'states' in info and info["states"].size:
+                        states = info["states"]
+                        timestamps = info["timestamps"]
+                        labels = info["labels"]
 
                         rising = np.where(states > 0)[0]
                         falling = np.where(states < 0)[0]
@@ -231,12 +232,12 @@ def _parse_header(self):
 
                         if len(rising) == len(falling):
                             durations = timestamps[falling] - timestamps[rising]
 
-                        d["rising"] = rising
-                        d["timestamps"] = timestamps[rising]
-                        d["labels"] = labels[rising]
-                        d["durations"] = durations
+                        info["rising"] = rising
+                        info["timestamps"] = timestamps[rising]
+                        info["labels"] = labels[rising]
+                        info["durations"] = durations
                     else:
-                        d["durations"] = None
+                        info["durations"] = None
 
         # no spike read yet
         # can be implemented on user demand
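# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# The hunk above pairs rising (states > 0) and falling (states < 0) TTL state
# changes to derive epoch durations, keeping one timestamp/label per rising
# edge. A self-contained toy version of that logic, with made-up data:
#
#     import numpy as np
#
#     states = np.array([1, -1, 2, -2])            # sign encodes edge direction
#     timestamps = np.array([0.0, 0.5, 1.0, 1.8])  # seconds
#     labels = np.array(["ch1", "ch1", "ch2", "ch2"])
#
#     rising = np.where(states > 0)[0]
#     falling = np.where(states < 0)[0]
#     if len(rising) == len(falling):
#         durations = timestamps[falling] - timestamps[rising]
#     else:
#         durations = None                         # unpaired edges: no durations
#
#     print(timestamps[rising], labels[rising], durations)
#     # -> [0. 1.] ['ch1' 'ch2'] [0.5 0.8]
# -------------------------------------------------------------------------------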
@@ -253,9 +254,9 @@ def _parse_header(self):
                 global_t_stop = None
 
                 # loop over signals
-                for stream_index, d in self._sig_streams[block_index][seg_index].items():
-                    t_start = d['t_start']
-                    dur = d['memmap'].shape[0] / float(d['sample_rate'])
+                for stream_index, info in self._sig_streams[block_index][seg_index].items():
+                    t_start = info['t_start']
+                    dur = info['memmap'].shape[0] / float(info['sample_rate'])
                     t_stop = t_start + dur
                     if global_t_start is None or global_t_start > t_start:
                         global_t_start = t_start
@@ -264,15 +265,15 @@ def _parse_header(self):
 
                 # loop over events
                 for stream_index, stream_name in enumerate(event_stream_names):
-                    d = self._evt_streams[block_index][seg_index][stream_index]
-                    if d['timestamps'].size == 0:
+                    info = self._evt_streams[block_index][seg_index][stream_index]
+                    if info['timestamps'].size == 0:
                         continue
-                    t_start = d['timestamps'][0]
-                    t_stop = d['timestamps'][-1]
+                    t_start = info['timestamps'][0]
+                    t_stop = info['timestamps'][-1]
                     if not self._use_direct_evt_timestamps:
-                        t_start /= d['sample_rate']
-                        t_stop /= d['sample_rate']
+                        t_start /= info['sample_rate']
+                        t_stop /= info['sample_rate']
                     if global_t_start is None or global_t_start > t_start:
                         global_t_start = t_start
@@ -301,35 +302,40 @@ def _parse_header(self):
             # array annotations for signal channels
             for stream_index, stream_name in enumerate(sig_stream_names):
                 sig_ann = seg_ann['signals'][stream_index]
-                d = self._sig_streams[0][0][stream_index]
+                info = self._sig_streams[block_index][seg_index][stream_index]
+                has_sync_trace = \
+                    self._sig_streams[block_index][seg_index][stream_index]['has_sync_trace']
+
                 for k in ('identifier', 'history', 'source_processor_index',
                           'recorded_processor_index'):
-                    if k in d['channels'][0]:
-                        values = np.array([chan_info[k] for chan_info in d['channels']])
+                    if k in info['channels'][0]:
+                        values = np.array([chan_info[k] for chan_info in info['channels']])
+                        if has_sync_trace:
+                            values = values[:-1]
                         sig_ann['__array_annotations__'][k] = values
 
             # array annotations for event channels
             # use other possible data in _possible_event_stream_names
             for stream_index, stream_name in enumerate(event_stream_names):
                 ev_ann = seg_ann['events'][stream_index]
-                d = self._evt_streams[0][0][stream_index]
-                if 'rising' in d:
-                    selected_indices = d["rising"]
+                info = self._evt_streams[0][0][stream_index]
+                if 'rising' in info:
+                    selected_indices = info["rising"]
                 else:
                     selected_indices = None
                 for k in _possible_event_stream_names:
                     if k in ('timestamps', 'rising'):
                         continue
-                    if k in d:
+                    if k in info:
                         # split custom dtypes into separate annotations
-                        if d[k].dtype.names:
-                            for name in d[k].dtype.names:
-                                arr_ann = d[k][name].flatten()
+                        if info[k].dtype.names:
+                            for name in info[k].dtype.names:
+                                arr_ann = info[k][name].flatten()
                                 if selected_indices is not None:
                                     arr_ann = arr_ann[selected_indices]
                                 ev_ann['__array_annotations__'][name] = arr_ann
                         else:
-                            arr_ann = d[k]
+                            arr_ann = info[k]
                             if selected_indices is not None:
                                 arr_ann = arr_ann[selected_indices]
                             ev_ann['__array_annotations__'][k] = arr_ann
@@ -360,6 +366,11 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
     def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
                                 stream_index, channel_indexes):
         sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap']
+        has_sync_trace = self._sig_streams[block_index][seg_index][stream_index]['has_sync_trace']
+
+        if not self.load_sync_channel and has_sync_trace:
+            sigs = sigs[:, :-1]
+
         sigs = sigs[i_start:i_stop, :]
         if channel_indexes is not None:
             sigs = sigs[:, channel_indexes]
@@ -383,15 +394,15 @@ def _event_count(self, block_index, seg_index, event_channel_index):
         return timestamps.size
 
     def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
-        d = self._evt_streams[block_index][seg_index][event_channel_index]
-        timestamps = d['timestamps']
-        durations = d["durations"]
-        labels = d['labels']
+        info = self._evt_streams[block_index][seg_index][event_channel_index]
+        timestamps = info['timestamps']
+        durations = info["durations"]
+        labels = info['labels']
         # slice it if needed
         if t_start is not None:
             if not self._use_direct_evt_timestamps:
-                ind_start = int(t_start * d['sample_rate'])
+                ind_start = int(t_start * info['sample_rate'])
                 mask = timestamps >= ind_start
             else:
                 mask = timestamps >= t_start
@@ -399,7 +410,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
             labels = labels[mask]
         if t_stop is not None:
             if not self._use_direct_evt_timestamps:
-                ind_stop = int(t_stop * d['sample_rate'])
+                ind_stop = int(t_stop * info['sample_rate'])
                 mask = timestamps < ind_stop
             else:
                 mask = timestamps < t_stop
@@ -408,17 +419,17 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
         return timestamps, durations, labels
 
     def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
-        d = self._evt_streams[0][0][event_channel_index]
+        info = self._evt_streams[0][0][event_channel_index]
         if not self._use_direct_evt_timestamps:
-            event_times = event_timestamps.astype(dtype) / float(d['sample_rate'])
+            event_times = event_timestamps.astype(dtype) / float(info['sample_rate'])
         else:
            event_times = event_timestamps.astype(dtype)
         return event_times
 
     def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
-        d = self._evt_streams[0][0][event_channel_index]
+        info = self._evt_streams[0][0][event_channel_index]
         if not self._use_direct_evt_timestamps:
-            durations = raw_duration.astype(dtype) / float(d['sample_rate'])
+            durations = raw_duration.astype(dtype) / float(info['sample_rate'])
         else:
             durations = raw_duration.astype(dtype)
         return durations
@@ -516,32 +527,32 @@ def explore_folder(dirname, experiment_names=None):
                 if (recording_folder / 'continuous').exists() and len(rec_structure['continuous']) > 0:
                     recording['streams']['continuous'] = {}
-                    for d in rec_structure['continuous']:
+                    for info in rec_structure['continuous']:
                         # when there are multiple Record Nodes the stream name
                         # also contains the node name to make it unique
-                        oe_stream_name = Path(d["folder_name"]).name  # remove trailing slash
+                        oe_stream_name = Path(info["folder_name"]).name  # remove trailing slash
                         if len(node_name) > 0:
                             stream_name = node_name + '#' + oe_stream_name
                         else:
                             stream_name = oe_stream_name
 
-                        raw_filename = recording_folder / 'continuous' / d['folder_name'] / 'continuous.dat'
+                        raw_filename = recording_folder / 'continuous' / info['folder_name'] / 'continuous.dat'
 
                         # Updates for OpenEphys v0.6:
                         # In new versions (>=0.6) timestamps.npy is called sample_numbers.npy
                         # see https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/Binary-format.html#continuous
-                        sample_numbers = recording_folder / 'continuous' / d['folder_name'] / \
+                        sample_numbers = recording_folder / 'continuous' / info['folder_name'] / \
                             'sample_numbers.npy'
                         if sample_numbers.is_file():
                             timestamp_file = sample_numbers
                         else:
-                            timestamp_file = recording_folder / 'continuous' / d['folder_name'] / \
+                            timestamp_file = recording_folder / 'continuous' / info['folder_name'] / \
                                 'timestamps.npy'
 
                         timestamps = np.load(str(timestamp_file), mmap_mode='r')
                         timestamp0 = timestamps[0]
-                        t_start = timestamp0 / d['sample_rate']
+                        t_start = timestamp0 / info['sample_rate']
 
                         # TODO for later: gap checking
-                        signal_stream = d.copy()
+                        signal_stream = info.copy()
                         signal_stream['raw_filename'] = str(raw_filename)
                         signal_stream['dtype'] = 'int16'
                         signal_stream['timestamp0'] = timestamp0
@@ -551,13 +562,13 @@ def explore_folder(dirname, experiment_names=None):
 
                 if (root / 'events').exists() and len(rec_structure['events']) > 0:
                     recording['streams']['events'] = {}
-                    for d in rec_structure['events']:
-                        oe_stream_name = Path(d["folder_name"]).name  # remove trailing slash
+                    for info in rec_structure['events']:
+                        oe_stream_name = Path(info["folder_name"]).name  # remove trailing slash
                         stream_name = node_name + '#' + oe_stream_name
-                        event_stream = d.copy()
+                        event_stream = info.copy()
 
                         for name in _possible_event_stream_names:
-                            npy_filename = root / 'events' / d['folder_name'] / f'{name}.npy'
+                            npy_filename = root / 'events' / info['folder_name'] / f'{name}.npy'
                             if npy_filename.is_file():
                                 event_stream[f'{name}_npy'] = str(npy_filename)
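# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# Minimal sketch of the sync-channel handling the OpenEphys changes above add:
# a stream "has a sync trace" when one of its channel names contains "SYNC",
# and with load_sync_channel=False the trailing column is sliced off the memmap.
# The channel list and signal array below are made-up stand-ins, not neo objects.
import numpy as np

channels = [{"channel_name": f"CH{i}"} for i in range(4)] + [{"channel_name": "SYNC"}]
has_sync_trace = any("SYNC" in ch["channel_name"] for ch in channels)

sigs = np.zeros((100, len(channels)), dtype="int16")  # stands in for the memmap
load_sync_channel = False
if not load_sync_channel and has_sync_trace:
    sigs = sigs[:, :-1]  # the sync trace is assumed to sit in the last column

print(has_sync_trace, sigs.shape)  # -> True (100, 4)
# -------------------------------------------------------------------------------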
" + "Set load_sync_channel to False") signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) @@ -181,7 +185,8 @@ def _parse_header(self): # one fake channel for "sys0" loc = np.concatenate((loc, [[0., 0.]]), axis=0) for ndim in range(loc.shape[1]): - sig_ann['__array_annotations__'][f'channel_location_{ndim}'] = loc[:, ndim] + sig_ann['__array_annotations__'][f'channel_location_{ndim}'] = \ + loc[:, ndim] def _segment_t_start(self, block_index, seg_index): return 0. @@ -201,25 +206,18 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): stream_id = self.header['signal_streams'][stream_index]['id'] memmap = self._memmaps[seg_index, stream_id] + stream_name = self.header['signal_streams']['name'][stream_index] + + # take care of sync channel + info = self.signals_info_dict[0, stream_name] + if not self.load_sync_channel and info['has_sync_trace']: + memmap = memmap[:, :-1] + + # since we cut the memmap, we can simplify the channel selection if channel_indexes is None: - if self.load_sync_channel: - channel_selection = slice(None) - else: - channel_selection = slice(-1) + channel_selection = slice(None) elif isinstance(channel_indexes, slice): - if self.load_sync_channel: - # simple - channel_selection = channel_indexes - else: - # more tricky because negative - sl_start = channel_indexes.start - sl_stop = channel_indexes.stop - sl_step = channel_indexes.step - if sl_stop is not None and sl_stop < 0: - sl_stop = sl_stop - 1 - elif sl_stop is None: - sl_stop = -1 - channel_selection = slice(sl_start, sl_stop, sl_step) + channel_selection = channel_indexes elif not isinstance(channel_indexes, slice): if np.all(np.diff(channel_indexes) == 1): # consecutive channel then slice this avoid a copy (because of ndarray.take(...) @@ -286,7 +284,7 @@ def parse_spikeglx_fname(fname): Parse recording identifiers from a SpikeGLX style filename. 
spikeglx naming follow this rules: - https://github.com/billkarsh/SpikeGLX/blob/master/Markdown/UserManual.md#gates-and-triggers + https://github.com/billkarsh/SpikeGLX/blob/15ec8898e17829f9f08c226bf04f46281f106e5f/Markdown/UserManual.md#gates-and-triggers Example file name structure: Consider the filenames: `Noise4Sam_g0_t0.nidq.bin` or `Noise4Sam_g0_t0.imec0.lf.bin` @@ -372,6 +370,13 @@ def extract_stream_info(meta_file, meta): """Extract info from the meta dict""" num_chan = int(meta['nSavedChans']) + if "snsApLfSy" in meta: + # AP and LF meta have this field + ap, lf, sy = [int(s) for s in meta["snsApLfSy"].split(",")] + has_sync_trace = sy == 1 + else: + # NIDQ case + has_sync_trace = False fname = Path(meta_file).stem run_name, gate_num, trigger_num, device, stream_kind = parse_spikeglx_fname(fname) @@ -385,11 +390,10 @@ def extract_stream_info(meta_file, meta): # metad['imroTbl'] contain two gain per channel AP and LF # except for the last fake channel per_channel_gain = np.ones(num_chan, dtype='float64') - if 'imDatPrb_type' not in meta or meta['imDatPrb_type'] == '0' or meta['imDatPrb_type'] in ('1015', '1022', '1030', '1031', '1032'): + if 'imDatPrb_type' not in meta or meta['imDatPrb_type'] == '0' or meta['imDatPrb_type'] \ + in ('1015', '1022', '1030', '1031', '1032'): # This work with NP 1.0 case with different metadata versions - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_3A.md#imec - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_3B1.md#imec - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_3B2.md#imec + # https://github.com/billkarsh/SpikeGLX/blob/15ec8898e17829f9f08c226bf04f46281f106e5f/Markdown/Metadata_30.md if stream_kind == 'ap': index_imroTbl = 3 elif stream_kind == 'lf': @@ -399,11 +403,11 @@ def extract_stream_info(meta_file, meta): per_channel_gain[c] = 1. / float(v) gain_factor = float(meta['imAiRangeMax']) / 512 channel_gains = gain_factor * per_channel_gain * 1e6 - elif meta['imDatPrb_type'] in ('21', '24') and stream_kind == 'ap': + elif meta['imDatPrb_type'] in ('21', '24', '2003', '2004', '2013', '2014'): # This work with NP 2.0 case with different metadata versions - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_20.md#channel-entries-by-type - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_20.md#imec - # https://github.com/billkarsh/SpikeGLX/blob/gh-pages/Support/Metadata_30.md#imec + # https://github.com/billkarsh/SpikeGLX/blob/15ec8898e17829f9f08c226bf04f46281f106e5f/Markdown/Metadata_30.md#imec + # We allow also LF streams for NP2.0 because CatGT can produce them + # See: https://github.com/SpikeInterface/spikeinterface/issues/1949 per_channel_gain[:-1] = 1 / 80. 
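# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# Sketch of the two metadata reads introduced above: "snsApLfSy" stores the
# AP, LF and SY (sync) channel counts, so sy == 1 means a sync trace was saved;
# NIDQ metadata lacks the field, so has_sync_trace defaults to False. For NP 2.0
# probes the per-channel gain is a fixed 1/80 with full scale imAiRangeMax/8192.
# The meta dict below is a made-up example, not read from a real .meta file.
import numpy as np

meta = {"snsApLfSy": "384,0,1", "imAiRangeMax": "0.5", "nSavedChans": "385"}

ap, lf, sy = [int(s) for s in meta["snsApLfSy"].split(",")]
has_sync_trace = sy == 1

num_chan = int(meta["nSavedChans"])
per_channel_gain = np.ones(num_chan, dtype="float64")
per_channel_gain[:-1] = 1 / 80.                       # last channel is the sync
gain_factor = float(meta["imAiRangeMax"]) / 8192
channel_gains = gain_factor * per_channel_gain * 1e6  # volts -> microvolts

print(has_sync_trace, channel_gains[0])  # -> True 0.762939453125
# -------------------------------------------------------------------------------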
diff --git a/neo/test/rawiotest/test_openephysbinaryrawio.py b/neo/test/rawiotest/test_openephysbinaryrawio.py
index 27fd33011..093f86f9f 100644
--- a/neo/test/rawiotest/test_openephysbinaryrawio.py
+++ b/neo/test/rawiotest/test_openephysbinaryrawio.py
@@ -6,16 +6,52 @@
 class TestOpenEphysBinaryRawIO(BaseTestRawIO, unittest.TestCase):
     rawioclass = OpenEphysBinaryRawIO
-    entities_to_download = [
-        'openephysbinary'
-    ]
+    entities_to_download = ["openephysbinary"]
 
     entities_to_test = [
-        'openephysbinary/v0.5.3_two_neuropixels_stream',
-        'openephysbinary/v0.4.4.1_with_video_tracking',
-        'openephysbinary/v0.5.x_two_nodes',
-        'openephysbinary/v0.6.x_neuropixels_multiexp_multistream',
+        "openephysbinary/v0.5.3_two_neuropixels_stream",
+        "openephysbinary/v0.4.4.1_with_video_tracking",
+        "openephysbinary/v0.5.x_two_nodes",
+        "openephysbinary/v0.6.x_neuropixels_multiexp_multistream",
+        "openephysbinary/v0.6.x_neuropixels_with_sync",
     ]
 
+    def test_sync(self):
+        rawio_with_sync = OpenEphysBinaryRawIO(
+            self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"),
+            load_sync_channel=True
+        )
+        rawio_with_sync.parse_header()
+        stream_name = [s_name for s_name in rawio_with_sync.header["signal_streams"]["name"]
+                       if "AP" in s_name][0]
+        stream_index = list(rawio_with_sync.header["signal_streams"]["name"]).index(stream_name)
+
+        # with the sync channel, the AP stream has 385 channels
+        chunk = rawio_with_sync.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 385
+
+        rawio_no_sync = OpenEphysBinaryRawIO(
+            self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"),
+            load_sync_channel=False
+        )
+        rawio_no_sync.parse_header()
+
+        # without the sync channel, the AP stream has 384 channels
+        chunk = rawio_no_sync.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 384
+
+    def test_no_sync(self):
+        # requesting the sync channel when there is none raises an error
+        with self.assertRaises(ValueError):
+            rawio_no_sync = OpenEphysBinaryRawIO(
+                self.get_local_path("openephysbinary/v0.6.x_neuropixels_multiexp_multistream"),
+                load_sync_channel=True
+            )
+            rawio_no_sync.parse_header()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/neo/test/rawiotest/test_spikeglxrawio.py b/neo/test/rawiotest/test_spikeglxrawio.py
index e4fd3a810..8b36dc0e7 100644
--- a/neo/test/rawiotest/test_spikeglxrawio.py
+++ b/neo/test/rawiotest/test_spikeglxrawio.py
@@ -10,37 +10,81 @@
 class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
     rawioclass = SpikeGLXRawIO
-    entities_to_download = [
-        'spikeglx'
-    ]
+    entities_to_download = ["spikeglx"]
 
     entities_to_test = [
-        'spikeglx/Noise4Sam_g0',
-        'spikeglx/TEST_20210920_0_g0',
-
+        "spikeglx/Noise4Sam_g0",
+        "spikeglx/TEST_20210920_0_g0",
         # this is only g0 multi index
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g0',
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g0",
         # this is only g1 multi index
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g1',
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g1",
         # this mixes both multi gate and multi trigger (and also multi probe)
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0',
-
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI1',
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI2',
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI3',
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4',
-        'spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI5',
-
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0",
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI1",
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI2",
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI3",
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4",
+        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI5",
+        # different sync/subset options with commercial NP2
+        "spikeglx/NP2_with_sync",
+        "spikeglx/NP2_no_sync",
+        "spikeglx/NP2_subset_with_sync",
     ]
 
     def test_with_location(self):
-        rawio = SpikeGLXRawIO(self.get_local_path('spikeglx/Noise4Sam_g0'), load_channel_location=True)
+        rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/Noise4Sam_g0"), load_channel_location=True)
         rawio.parse_header()
         # one of the streams has channel locations
         have_location = []
-        for sig_anotations in rawio.raw_annotations['blocks'][0]['segments'][0]['signals']:
-            have_location.append('channel_location_0' in sig_anotations['__array_annotations__'])
+        for sig_annotations in rawio.raw_annotations["blocks"][0]["segments"][0]["signals"]:
+            have_location.append("channel_location_0" in sig_annotations["__array_annotations__"])
         assert any(have_location)
 
+    def test_sync(self):
+        rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)
+        rawio_with_sync.parse_header()
+        stream_index = list(rawio_with_sync.header["signal_streams"]["name"]).index("imec0.ap")
+
+        # with the sync channel, the AP stream has 385 channels
+        chunk = rawio_with_sync.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 385
+
+        rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False)
+        rawio_no_sync.parse_header()
+
+        # without the sync channel, the AP stream has 384 channels
+        chunk = rawio_no_sync.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 384
+
+    def test_no_sync(self):
+        # requesting the sync channel when there is none raises an error
+        with self.assertRaises(ValueError):
+            rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_no_sync"), load_sync_channel=True)
+            rawio_no_sync.parse_header()
+
+    def test_subset_with_sync(self):
+        rawio_sub = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_subset_with_sync"), load_sync_channel=True)
+        rawio_sub.parse_header()
+        stream_index = list(rawio_sub.header["signal_streams"]["name"]).index("imec0.ap")
+
+        # with the sync channel, the subset AP stream has 121 channels
+        chunk = rawio_sub.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 121
+
+        rawio_sub_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_subset_with_sync"), load_sync_channel=False)
+        rawio_sub_no_sync.parse_header()
+
+        # without the sync channel, the subset AP stream has 120 channels
+        chunk = rawio_sub_no_sync.get_analogsignal_chunk(
+            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
+        )
+        assert chunk.shape[1] == 120
+
 
 if __name__ == "__main__":
     unittest.main()
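# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# For reference, the SpikeGLX filename convention that parse_spikeglx_fname()
# (used by extract_stream_info above) relies on is
# `<run_name>_g<gate>_t<trigger>.<device>.<stream_kind>`, e.g.
# `Noise4Sam_g0_t0.imec0.lf`. The regex below is a simplified, hypothetical
# stand-in for illustration, not the implementation in neo.
import re

def parse_fname_sketch(fname):
    m = re.match(r"(.+)_g(\d+)_t(\d+)\.(imec\d+|nidq)\.?(ap|lf)?$", fname)
    run_name, gate_num, trigger_num, device, stream_kind = m.groups()
    return run_name, int(gate_num), int(trigger_num), device, stream_kind

print(parse_fname_sketch("Noise4Sam_g0_t0.imec0.lf"))
# -> ('Noise4Sam', 0, 0, 'imec0', 'lf')
print(parse_fname_sketch("Noise4Sam_g0_t0.nidq"))
# -> ('Noise4Sam', 0, 0, 'nidq', None)
# -------------------------------------------------------------------------------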