diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py
index 358635a33..82d50e3d1 100644
--- a/neo/rawio/plexonrawio.py
+++ b/neo/rawio/plexonrawio.py
@@ -21,22 +21,43 @@
 Author: Samuel Garcia
 
 """
-
-from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
-                        _spike_channel_dtype, _event_channel_dtype)
+import datetime
+from collections import OrderedDict
 import numpy as np
-from collections import OrderedDict
-import datetime
+try:
+    from tqdm import tqdm, trange
+    HAVE_TQDM = True
+except ImportError:
+    HAVE_TQDM = False
+
+from .baserawio import (
+    BaseRawIO,
+    _signal_channel_dtype,
+    _signal_stream_dtype,
+    _spike_channel_dtype,
+    _event_channel_dtype,
+)
 
 
 class PlexonRawIO(BaseRawIO):
     extensions = ['plx']
     rawmode = 'one-file'
 
-    def __init__(self, filename=''):
+    def __init__(self, filename='', progress_bar=True):
+        """
+
+        Parameters
+        ----------
+        filename : str
+            The filename.
+        progress_bar : bool, default True
+            Display a progress bar using tqdm (if installed) while parsing the file.
+
+        """
         BaseRawIO.__init__(self)
         self.filename = filename
+        self.progress_bar = HAVE_TQDM and progress_bar
 
     def _source_name(self):
         return self.filename
@@ -45,43 +66,57 @@ def _parse_header(self):
 
         # global header
         with open(self.filename, 'rb') as fid:
-            offset0 = 0
-            global_header = read_as_dict(fid, GlobalHeader, offset=offset0)
+            global_header = read_as_dict(fid, GlobalHeader)
 
-        rec_datetime = datetime.datetime(global_header['Year'],
-                                         global_header['Month'],
-                                         global_header['Day'],
-                                         global_header['Hour'],
-                                         global_header['Minute'],
-                                         global_header['Second'])
+        rec_datetime = datetime.datetime(
+            global_header['Year'],
+            global_header['Month'],
+            global_header['Day'],
+            global_header['Hour'],
+            global_header['Minute'],
+            global_header['Second'],
+        )
 
         # dsp channels header = spikes and waveforms
         nb_unit_chan = global_header['NumDSPChannels']
         offset1 = np.dtype(GlobalHeader).itemsize
-        dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r',
-                                      offset=offset1, shape=(nb_unit_chan,))
+        dspChannelHeaders = np.memmap(
+            self.filename, dtype=DspChannelHeader, mode='r', offset=offset1, shape=(nb_unit_chan,)
+        )
 
         # event channel header
         nb_event_chan = global_header['NumEventChannels']
         offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan
-        eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r',
-                                 offset=offset2, shape=(nb_event_chan,))
+        eventHeaders = np.memmap(
+            self.filename,
+            dtype=EventChannelHeader,
+            mode='r',
+            offset=offset2,
+            shape=(nb_event_chan,),
+        )
 
         # slow channel header = signal
         nb_sig_chan = global_header['NumSlowChannels']
         offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan
-        slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r',
-                                       offset=offset3, shape=(nb_sig_chan,))
+        slowChannelHeaders = np.memmap(
+            self.filename, dtype=SlowChannelHeader, mode='r', offset=offset3, shape=(nb_sig_chan,)
+        )
 
         offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan
 
         # locate data blocks and group them by type and channel
-        block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']},
-                     4: {c: [] for c in eventHeaders['Channel']},
-                     5: {c: [] for c in slowChannelHeaders['Channel']},
-                     }
+        block_pos = {
+            1: {c: [] for c in dspChannelHeaders['Channel']},
+            4: {c: [] for c in eventHeaders['Channel']},
+            5: {c: [] for c in slowChannelHeaders['Channel']},
+        }
         data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r')
         pos = offset4
+
+        # create a tqdm progress bar whose total is the file size in bytes
+        if self.progress_bar:
+            progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)
+
         while pos < data.size:
             bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
             length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16
@@ -90,6 +125,13 @@ def _parse_header(self):
             block_pos[bl_type][chan_id].append(pos)
             pos += length
+            # update tqdm with the number of bytes processed in this iteration
+            if self.progress_bar:
+                progress_bar.update(length)
+
+        if self.progress_bar:
+            progress_bar.close()
+
         self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
             2 ** 32 + bl_header['TimeStamp']
@@ -105,9 +147,21 @@ def _parse_header(self):
             # Signals
             5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
         }
-        for bl_type in block_pos:
+        if self.progress_bar:
+            bl_loop = tqdm(block_pos, desc="Finalizing data blocks", leave=True)
+        else:
+            bl_loop = block_pos
+        for bl_type in bl_loop:
             self._data_blocks[bl_type] = {}
-            for chan_id in block_pos[bl_type]:
+            if self.progress_bar:
+                chan_loop = tqdm(
+                    block_pos[bl_type],
+                    desc="Finalizing data blocks for type %d" % bl_type,
+                    leave=True,
+                )
+            else:
+                chan_loop = block_pos[bl_type]
+            for chan_id in chan_loop:
                 positions = block_pos[bl_type][chan_id]
                 dt = dtype_by_bltype[bl_type]
                 data_block = np.empty((len(positions)), dtype=dt)
@@ -132,7 +186,7 @@ def _parse_header(self):
                     data_block['label'][index] = bl_header['Unit']
                 elif bl_type == 5:  # Signals
                     if data_block.size > 0:
-                        # cumulative some of sample index for fast access to chunks
+                        # cumulative sum of sample index for fast access to chunks
                         if index == 0:
                             data_block['cumsum'][index] = 0
                         else:
@@ -143,7 +197,11 @@ def _parse_header(self):
         # signals channels
         sig_channels = []
         all_sig_length = []
-        for chan_index in range(nb_sig_chan):
+        if self.progress_bar:
+            chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
+        else:
+            chan_loop = range(nb_sig_chan)
+        for chan_index in chan_loop:
             h = slowChannelHeaders[chan_index]
             name = h['Name'].decode('utf8')
             chan_id = h['Channel']
@@ -164,8 +222,9 @@ def _parse_header(self):
                     h['Gain'] * h['PreampGain'])
             offset = 0.
             stream_id = '0'
-            sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype,
-                                 units, gain, offset, stream_id))
+            sig_channels.append(
+                (name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id)
+            )
 
         sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
@@ -203,7 +262,16 @@ def _parse_header(self):
 
         # Spikes channels
         spike_channels = []
-        for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids):
+        if self.progress_bar:
+            unit_loop = tqdm(
+                enumerate(self.internal_unit_ids),
+                desc="Parsing spike channels",
+                leave=True,
+            )
+        else:
+            unit_loop = enumerate(self.internal_unit_ids)
+
+        for unit_index, (chan_id, unit_id) in unit_loop:
             c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
             h = dspChannelHeaders[c]
@@ -223,28 +291,33 @@ def _parse_header(self):
             wf_offset = 0.
             wf_left_sweep = -1  # DONT KNOWN
             wf_sampling_rate = global_header['WaveformFreq']
-            spike_channels.append((name, _id, wf_units, wf_gain, wf_offset,
-                                   wf_left_sweep, wf_sampling_rate))
+            spike_channels.append(
+                (name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)
+            )
 
         spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
 
         # Event channels
         event_channels = []
-        for chan_index in range(nb_event_chan):
+        if self.progress_bar:
+            chan_loop = trange(nb_event_chan, desc="Parsing event channels", leave=True)
+        else:
+            chan_loop = range(nb_event_chan)
+        for chan_index in chan_loop:
             h = eventHeaders[chan_index]
             chan_id = h['Channel']
             name = h['Name'].decode('utf8')
-            _id = h['Channel']
-            event_channels.append((name, _id, 'event'))
+            event_channels.append((name, chan_id, 'event'))
 
         event_channels = np.array(event_channels, dtype=_event_channel_dtype)
 
-        # fille into header dict
-        self.header = {}
-        self.header['nb_block'] = 1
-        self.header['nb_segment'] = [1]
-        self.header['signal_streams'] = signal_streams
-        self.header['signal_channels'] = sig_channels
-        self.header['spike_channels'] = spike_channels
-        self.header['event_channels'] = event_channels
+        # fill into header dict
+        self.header = {
+            "nb_block": 1,
+            "nb_segment": [1],
+            "signal_streams": signal_streams,
+            "signal_channels": sig_channels,
+            "spike_channels": spike_channels,
+            "event_channels": event_channels,
+        }
 
         # Annotations
         self._generate_minimal_annotations()
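
Usage note: a minimal sketch of how the new progress_bar flag is exercised. The file path below is hypothetical; tqdm is optional, and parsing stays silent when it is not installed because self.progress_bar then evaluates to False.

    from neo.rawio import PlexonRawIO

    # hypothetical path, any Plexon .plx recording works here
    reader = PlexonRawIO(filename='my_recording.plx', progress_bar=True)

    # parse_header() drives _parse_header(); with tqdm installed and
    # progress_bar=True, byte-level progress is reported while the data
    # blocks are located and grouped
    reader.parse_header()

    # pass progress_bar=False for silent parsing even when tqdm is available
    quiet = PlexonRawIO(filename='my_recording.plx', progress_bar=False)
    quiet.parse_header()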