From ded9a669356514be59dd533e118ef68e37418290 Mon Sep 17 00:00:00 2001 From: bendichter Date: Wed, 1 Nov 2023 11:04:32 -0400 Subject: [PATCH 1/4] * optimize imports * add tqdm to longer steps * a bit of black-like formatting * add tqdm as a requirement for plexon --- neo/rawio/plexonrawio.py | 108 +++++++++++++++++++++++---------------- pyproject.toml | 1 + 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index 358635a33..b8976d98d 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -21,13 +21,15 @@ Author: Samuel Garcia """ - -from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, - _spike_channel_dtype, _event_channel_dtype) +import datetime +from collections import OrderedDict import numpy as np -from collections import OrderedDict -import datetime +from tqdm import tqdm, trange + +from .baserawio import ( + BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype +) class PlexonRawIO(BaseRawIO): @@ -45,43 +47,52 @@ def _parse_header(self): # global header with open(self.filename, 'rb') as fid: - offset0 = 0 - global_header = read_as_dict(fid, GlobalHeader, offset=offset0) + global_header = read_as_dict(fid, GlobalHeader) - rec_datetime = datetime.datetime(global_header['Year'], - global_header['Month'], - global_header['Day'], - global_header['Hour'], - global_header['Minute'], - global_header['Second']) + rec_datetime = datetime.datetime( + global_header['Year'], + global_header['Month'], + global_header['Day'], + global_header['Hour'], + global_header['Minute'], + global_header['Second'], + ) # dsp channels header = spikes and waveforms nb_unit_chan = global_header['NumDSPChannels'] offset1 = np.dtype(GlobalHeader).itemsize - dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r', - offset=offset1, shape=(nb_unit_chan,)) + dspChannelHeaders = np.memmap( + self.filename, 
dtype=DspChannelHeader, mode='r', offset=offset1, shape=(nb_unit_chan,) + ) # event channel header nb_event_chan = global_header['NumEventChannels'] offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan - eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r', - offset=offset2, shape=(nb_event_chan,)) + eventHeaders = np.memmap( + self.filename, dtype=EventChannelHeader, mode='r', offset=offset2, shape=(nb_event_chan,) + ) # slow channel header = signal nb_sig_chan = global_header['NumSlowChannels'] offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan - slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r', - offset=offset3, shape=(nb_sig_chan,)) + slowChannelHeaders = np.memmap( + self.filename, dtype=SlowChannelHeader, mode='r', offset=offset3, shape=(nb_sig_chan,) + ) offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan # locate data blocks and group them by type and channel - block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']}, - 4: {c: [] for c in eventHeaders['Channel']}, - 5: {c: [] for c in slowChannelHeaders['Channel']}, - } + block_pos = { + 1: {c: [] for c in dspChannelHeaders['Channel']}, + 4: {c: [] for c in eventHeaders['Channel']}, + 5: {c: [] for c in slowChannelHeaders['Channel']}, + } data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r') pos = offset4 + + # Create a tqdm object with a total of len(data) and an initial value of 0 for offset + progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True) + while pos < data.size: bl_header = data[pos:pos + 16].view(DataBlockHeader)[0] length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16 @@ -90,6 +101,11 @@ def _parse_header(self): block_pos[bl_type][chan_id].append(pos) pos += length + # Update tqdm with the number of bytes processed in this iteration + progress_bar.update(length) + + progress_bar.close() 
+ self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \ 2 ** 32 + bl_header['TimeStamp'] @@ -105,9 +121,9 @@ def _parse_header(self): # Signals 5: np.dtype(dt_base + [('cumsum', 'int64'), ]), } - for bl_type in block_pos: + for bl_type in tqdm(block_pos, desc="Finalizing data blocks", leave=True): self._data_blocks[bl_type] = {} - for chan_id in block_pos[bl_type]: + for chan_id in tqdm(block_pos[bl_type], desc="Finalizing data blocks for type %d" % bl_type, leave=True): positions = block_pos[bl_type][chan_id] dt = dtype_by_bltype[bl_type] data_block = np.empty((len(positions)), dtype=dt) @@ -132,7 +148,7 @@ def _parse_header(self): data_block['label'][index] = bl_header['Unit'] elif bl_type == 5: # Signals if data_block.size > 0: - # cumulative some of sample index for fast access to chunks + # cumulative sum of sample index for fast access to chunks if index == 0: data_block['cumsum'][index] = 0 else: @@ -143,7 +159,7 @@ def _parse_header(self): # signals channels sig_channels = [] all_sig_length = [] - for chan_index in range(nb_sig_chan): + for chan_index in trange(nb_sig_chan, desc="Parsing signal channels", leave=True): h = slowChannelHeaders[chan_index] name = h['Name'].decode('utf8') chan_id = h['Channel'] @@ -164,8 +180,9 @@ def _parse_header(self): h['Gain'] * h['PreampGain']) offset = 0. 
stream_id = '0' - sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype, - units, gain, offset, stream_id)) + sig_channels.append( + (name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id) + ) sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype) @@ -203,7 +220,7 @@ def _parse_header(self): # Spikes channels spike_channels = [] - for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids): + for unit_index, (chan_id, unit_id) in tqdm(enumerate(self.internal_unit_ids), desc="Parsing spike channels", leave=True): c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0] h = dspChannelHeaders[c] @@ -223,28 +240,29 @@ def _parse_header(self): wf_offset = 0. wf_left_sweep = -1 # DONT KNOWN wf_sampling_rate = global_header['WaveformFreq'] - spike_channels.append((name, _id, wf_units, wf_gain, wf_offset, - wf_left_sweep, wf_sampling_rate)) + spike_channels.append( + (name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate) + ) spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype) # Event channels event_channels = [] - for chan_index in range(nb_event_chan): + for chan_index in trange(nb_event_chan, desc="Parsing event channels", leave=True): h = eventHeaders[chan_index] chan_id = h['Channel'] name = h['Name'].decode('utf8') - _id = h['Channel'] - event_channels.append((name, _id, 'event')) + event_channels.append((name, chan_id, 'event')) event_channels = np.array(event_channels, dtype=_event_channel_dtype) - # fille into header dict - self.header = {} - self.header['nb_block'] = 1 - self.header['nb_segment'] = [1] - self.header['signal_streams'] = signal_streams - self.header['signal_channels'] = sig_channels - self.header['spike_channels'] = spike_channels - self.header['event_channels'] = event_channels + # fill into header dict + self.header = { + "nb_block": 1, + "nb_segment": [1], + "signal_streams": signal_streams, + "signal_channels": sig_channels, + "spike_channels": 
spike_channels, + "event_channels": event_channels, + } # Annotations self._generate_minimal_annotations() @@ -399,13 +417,13 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) return event_times -def read_as_dict(fid, dtype, offset=None): +def read_as_dict(fid, dtype, offset: int = 0): """ Given a file descriptor and a numpy.dtype of the binary struct return a dict. Make conversion for strings. """ - if offset is not None: + if offset: fid.seek(offset) dt = np.dtype(dtype) h = np.frombuffer(fid.read(dt.itemsize), dt)[0] diff --git a/pyproject.toml b/pyproject.toml index d5228d753..763d0951b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ nwb = ["pynwb"] maxwell = ["h5py"] biocam = ["h5py"] med = ["dhn_med_py>=1.0.0"] +plexon = ["tqdm"] plexon2 = ["zugbruecke>=0.2; sys_platform!='win32'", "wenv; sys_platform!='win32'"] all = [ From 8c747f45825152bd8dd5fa1a217d257951121fdf Mon Sep 17 00:00:00 2001 From: bendichter Date: Wed, 1 Nov 2023 11:10:31 -0400 Subject: [PATCH 2/4] pep8 line lengths --- neo/rawio/plexonrawio.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index b8976d98d..31519d65f 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -28,7 +28,11 @@ from tqdm import tqdm, trange from .baserawio import ( - BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype + BaseRawIO, + _signal_channel_dtype, + _signal_stream_dtype, + _spike_channel_dtype, + _event_channel_dtype, ) @@ -69,7 +73,11 @@ def _parse_header(self): nb_event_chan = global_header['NumEventChannels'] offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan eventHeaders = np.memmap( - self.filename, dtype=EventChannelHeader, mode='r', offset=offset2, shape=(nb_event_chan,) + self.filename, + dtype=EventChannelHeader, + mode='r', + offset=offset2, + shape=(nb_event_chan,), ) # 
slow channel header = signal @@ -123,7 +131,11 @@ def _parse_header(self): } for bl_type in tqdm(block_pos, desc="Finalizing data blocks", leave=True): self._data_blocks[bl_type] = {} - for chan_id in tqdm(block_pos[bl_type], desc="Finalizing data blocks for type %d" % bl_type, leave=True): + for chan_id in tqdm( + block_pos[bl_type], + desc="Finalizing data blocks for type %d" % bl_type, + leave=True, + ): positions = block_pos[bl_type][chan_id] dt = dtype_by_bltype[bl_type] data_block = np.empty((len(positions)), dtype=dt) @@ -220,7 +232,11 @@ def _parse_header(self): # Spikes channels spike_channels = [] - for unit_index, (chan_id, unit_id) in tqdm(enumerate(self.internal_unit_ids), desc="Parsing spike channels", leave=True): + for unit_index, (chan_id, unit_id) in tqdm( + enumerate(self.internal_unit_ids), + desc="Parsing spike channels", + leave=True, + ): c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0] h = dspChannelHeaders[c] From 22321d3b424c1ef99cff9ec8bcb8dd6cbb5f9f07 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 2 Feb 2024 11:40:59 +0100 Subject: [PATCH 3/4] Make tqdm dependency optional for PlexonRawIO and add an option to show or hide the progress bar --- neo/rawio/plexonrawio.py | 63 ++++++++++++++++++++++++++++------------ pyproject.toml | 1 - 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index 31519d65f..668b43db7 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -25,7 +25,11 @@ from collections import OrderedDict import numpy as np -from tqdm import tqdm, trange +try: + from tqdm import tqdm, trange + HAVE_TQDM = True +except: + HAVE_TQDM = False from .baserawio import ( BaseRawIO, @@ -40,9 +44,10 @@ class PlexonRawIO(BaseRawIO): extensions = ['plx'] rawmode = 'one-file' - def __init__(self, filename=''): + def __init__(self, filename='', progress_bar=True): BaseRawIO.__init__(self) self.filename = filename + self.progress_bar = 
HAVE_TQDM and progress_bar def _source_name(self): return self.filename @@ -99,7 +104,8 @@ def _parse_header(self): pos = offset4 # Create a tqdm object with a total of len(data) and an initial value of 0 for offset - progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True) + if self.progress_bar : + progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True) while pos < data.size: bl_header = data[pos:pos + 16].view(DataBlockHeader)[0] @@ -110,9 +116,11 @@ def _parse_header(self): pos += length # Update tqdm with the number of bytes processed in this iteration - progress_bar.update(length) + if self.progress_bar : + progress_bar.update(length) - progress_bar.close() + if self.progress_bar : + progress_bar.close() self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \ 2 ** 32 + bl_header['TimeStamp'] @@ -129,13 +137,21 @@ def _parse_header(self): # Signals 5: np.dtype(dt_base + [('cumsum', 'int64'), ]), } - for bl_type in tqdm(block_pos, desc="Finalizing data blocks", leave=True): + if self.progress_bar : + bl_loop = tqdm(block_pos, desc="Finalizing data blocks", leave=True) + else: + bl_loop = block_pos + for bl_type in bl_loop: self._data_blocks[bl_type] = {} - for chan_id in tqdm( - block_pos[bl_type], - desc="Finalizing data blocks for type %d" % bl_type, - leave=True, - ): + if self.progress_bar : + chan_loop = tqdm( + block_pos[bl_type], + desc="Finalizing data blocks for type %d" % bl_type, + leave=True, + ) + else: + chan_loop = block_pos[bl_type] + for chan_id in chan_loop: positions = block_pos[bl_type][chan_id] dt = dtype_by_bltype[bl_type] data_block = np.empty((len(positions)), dtype=dt) @@ -171,7 +187,11 @@ def _parse_header(self): # signals channels sig_channels = [] all_sig_length = [] - for chan_index in trange(nb_sig_chan, desc="Parsing signal channels", leave=True): + if self.progress_bar: + chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True) + else: + 
chan_loop = range(nb_sig_chan) + for chan_index in chan_loop: h = slowChannelHeaders[chan_index] name = h['Name'].decode('utf8') chan_id = h['Channel'] @@ -232,11 +252,16 @@ def _parse_header(self): # Spikes channels spike_channels = [] - for unit_index, (chan_id, unit_id) in tqdm( - enumerate(self.internal_unit_ids), - desc="Parsing spike channels", - leave=True, - ): + if self.progress_bar: + unit_loop = tqdm( + enumerate(self.internal_unit_ids), + desc="Parsing spike channels", + leave=True, + ) + else: + unit_loop = enumerate(self.internal_unit_ids) + + for unit_index, (chan_id, unit_id) in unit_loop: c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0] h = dspChannelHeaders[c] @@ -433,13 +458,13 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) return event_times -def read_as_dict(fid, dtype, offset: int = 0): +def read_as_dict(fid, dtype, offset=None): """ Given a file descriptor and a numpy.dtype of the binary struct return a dict. Make conversion for strings. 
""" - if offset: + if offset is not None: fid.seek(offset) dt = np.dtype(dtype) h = np.frombuffer(fid.read(dt.itemsize), dt)[0] diff --git a/pyproject.toml b/pyproject.toml index 763d0951b..d5228d753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,6 @@ nwb = ["pynwb"] maxwell = ["h5py"] biocam = ["h5py"] med = ["dhn_med_py>=1.0.0"] -plexon = ["tqdm"] plexon2 = ["zugbruecke>=0.2; sys_platform!='win32'", "wenv; sys_platform!='win32'"] all = [ From 603a89b875df1c99148c1135f154747960f62730 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 2 Feb 2024 15:24:27 +0100 Subject: [PATCH 4/4] oups --- neo/rawio/plexonrawio.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index 668b43db7..82d50e3d1 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -45,6 +45,16 @@ class PlexonRawIO(BaseRawIO): rawmode = 'one-file' def __init__(self, filename='', progress_bar=True): + """ + + Parameters + ---------- + filename: str + The filename. + progress_bar: bool, default True + Display progress bar using tqdm (if installed) when parsing the file. + + """ BaseRawIO.__init__(self) self.filename = filename self.progress_bar = HAVE_TQDM and progress_bar @@ -288,7 +298,11 @@ def _parse_header(self): # Event channels event_channels = [] - for chan_index in trange(nb_event_chan, desc="Parsing event channels", leave=True): + if self.progress_bar: + chan_loop = trange(nb_event_chan, desc="Parsing event channels", leave=True) + else: + chan_loop = range(nb_event_chan) + for chan_index in chan_loop: h = eventHeaders[chan_index] chan_id = h['Channel'] name = h['Name'].decode('utf8')