Commit

Merge pull request #1343 from bendichter/acc_plexon
PlexonRawIO style
apdavison authored Feb 2, 2024
2 parents 288e101 + 603a89b commit ca7e142
Showing 1 changed file with 117 additions and 44 deletions.
161 changes: 117 additions & 44 deletions neo/rawio/plexonrawio.py
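
For context: the option this commit introduces can be exercised as in the minimal sketch below. The file path is illustrative, and tqdm remains an optional dependency (the bars are simply skipped when it is not installed).

    from neo.rawio import PlexonRawIO

    # 'my_recording.plx' is a hypothetical path, for illustration only
    reader = PlexonRawIO(filename='my_recording.plx', progress_bar=True)
    reader.parse_header()  # shows tqdm progress bars while the .plx file is scanned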
@@ -21,22 +21,43 @@
Author: Samuel Garcia
"""

from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
                        _spike_channel_dtype, _event_channel_dtype)
import datetime
from collections import OrderedDict

import numpy as np
from collections import OrderedDict
import datetime
try:
    from tqdm import tqdm, trange
    HAVE_TQDM = True
except ImportError:
    HAVE_TQDM = False

from .baserawio import (
    BaseRawIO,
    _signal_channel_dtype,
    _signal_stream_dtype,
    _spike_channel_dtype,
    _event_channel_dtype,
)


class PlexonRawIO(BaseRawIO):
    extensions = ['plx']
    rawmode = 'one-file'

    def __init__(self, filename=''):
    def __init__(self, filename='', progress_bar=True):
        """
        Parameters
        ----------
        filename: str
            The filename.
        progress_bar: bool, default True
            Display progress bar using tqdm (if installed) when parsing the file.
        """
        BaseRawIO.__init__(self)
        self.filename = filename
        self.progress_bar = HAVE_TQDM and progress_bar

    def _source_name(self):
        return self.filename
@@ -45,43 +66,57 @@ def _parse_header(self):

        # global header
        with open(self.filename, 'rb') as fid:
            offset0 = 0
            global_header = read_as_dict(fid, GlobalHeader, offset=offset0)
            global_header = read_as_dict(fid, GlobalHeader)

        rec_datetime = datetime.datetime(global_header['Year'],
                                         global_header['Month'],
                                         global_header['Day'],
                                         global_header['Hour'],
                                         global_header['Minute'],
                                         global_header['Second'])
        rec_datetime = datetime.datetime(
            global_header['Year'],
            global_header['Month'],
            global_header['Day'],
            global_header['Hour'],
            global_header['Minute'],
            global_header['Second'],
        )

        # dsp channels header = spikes and waveforms
        nb_unit_chan = global_header['NumDSPChannels']
        offset1 = np.dtype(GlobalHeader).itemsize
        dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r',
                                      offset=offset1, shape=(nb_unit_chan,))
        dspChannelHeaders = np.memmap(
            self.filename, dtype=DspChannelHeader, mode='r', offset=offset1, shape=(nb_unit_chan,)
        )

        # event channel header
        nb_event_chan = global_header['NumEventChannels']
        offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan
        eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r',
                                 offset=offset2, shape=(nb_event_chan,))
        eventHeaders = np.memmap(
            self.filename,
            dtype=EventChannelHeader,
            mode='r',
            offset=offset2,
            shape=(nb_event_chan,),
        )

        # slow channel header = signal
        nb_sig_chan = global_header['NumSlowChannels']
        offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan
        slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r',
                                       offset=offset3, shape=(nb_sig_chan,))
        slowChannelHeaders = np.memmap(
            self.filename, dtype=SlowChannelHeader, mode='r', offset=offset3, shape=(nb_sig_chan,)
        )

        offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan

        # locate data blocks and group them by type and channel
        block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']},
                     4: {c: [] for c in eventHeaders['Channel']},
                     5: {c: [] for c in slowChannelHeaders['Channel']},
                     }
        block_pos = {
            1: {c: [] for c in dspChannelHeaders['Channel']},
            4: {c: [] for c in eventHeaders['Channel']},
            5: {c: [] for c in slowChannelHeaders['Channel']},
        }
        data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r')
        pos = offset4

        # Create a tqdm object with a total of len(data) and an initial value of 0 for offset
        if self.progress_bar:
            progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)

        while pos < data.size:
            bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
            length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16
@@ -90,6 +125,13 @@ def _parse_header(self):
            block_pos[bl_type][chan_id].append(pos)
            pos += length

            # Update tqdm with the number of bytes processed in this iteration
            if self.progress_bar:
                progress_bar.update(length)

        if self.progress_bar:
            progress_bar.close()

        self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
            2 ** 32 + bl_header['TimeStamp']

@@ -105,9 +147,21 @@ def _parse_header(self):
            # Signals
            5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
        }
        for bl_type in block_pos:
        if self.progress_bar:
            bl_loop = tqdm(block_pos, desc="Finalizing data blocks", leave=True)
        else:
            bl_loop = block_pos
        for bl_type in bl_loop:
            self._data_blocks[bl_type] = {}
            for chan_id in block_pos[bl_type]:
            if self.progress_bar:
                chan_loop = tqdm(
                    block_pos[bl_type],
                    desc="Finalizing data blocks for type %d" % bl_type,
                    leave=True,
                )
            else:
                chan_loop = block_pos[bl_type]
            for chan_id in chan_loop:
                positions = block_pos[bl_type][chan_id]
                dt = dtype_by_bltype[bl_type]
                data_block = np.empty((len(positions)), dtype=dt)
@@ -132,7 +186,7 @@ def _parse_header(self):
                        data_block['label'][index] = bl_header['Unit']
                    elif bl_type == 5:  # Signals
                        if data_block.size > 0:
                            # cumulative some of sample index for fast access to chunks
                            # cumulative sum of sample index for fast access to chunks
                            if index == 0:
                                data_block['cumsum'][index] = 0
                            else:
@@ -143,7 +197,11 @@ def _parse_header(self):
        # signals channels
        sig_channels = []
        all_sig_length = []
        for chan_index in range(nb_sig_chan):
        if self.progress_bar:
            chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
        else:
            chan_loop = range(nb_sig_chan)
        for chan_index in chan_loop:
            h = slowChannelHeaders[chan_index]
            name = h['Name'].decode('utf8')
            chan_id = h['Channel']
@@ -164,8 +222,9 @@ def _parse_header(self):
                    h['Gain'] * h['PreampGain'])
            offset = 0.
            stream_id = '0'
            sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype,
                                 units, gain, offset, stream_id))
            sig_channels.append(
                (name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id)
            )

        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

@@ -203,7 +262,16 @@ def _parse_header(self):

        # Spikes channels
        spike_channels = []
        for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids):
        if self.progress_bar:
            unit_loop = tqdm(
                enumerate(self.internal_unit_ids),
                desc="Parsing spike channels",
                leave=True,
            )
        else:
            unit_loop = enumerate(self.internal_unit_ids)

        for unit_index, (chan_id, unit_id) in unit_loop:
            c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
            h = dspChannelHeaders[c]

@@ -223,28 +291,33 @@ def _parse_header(self):
            wf_offset = 0.
            wf_left_sweep = -1  # DON'T KNOW
            wf_sampling_rate = global_header['WaveformFreq']
            spike_channels.append((name, _id, wf_units, wf_gain, wf_offset,
                                   wf_left_sweep, wf_sampling_rate))
            spike_channels.append(
                (name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)
            )
        spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

        # Event channels
        event_channels = []
        for chan_index in range(nb_event_chan):
        if self.progress_bar:
            chan_loop = trange(nb_event_chan, desc="Parsing event channels", leave=True)
        else:
            chan_loop = range(nb_event_chan)
        for chan_index in chan_loop:
            h = eventHeaders[chan_index]
            chan_id = h['Channel']
            name = h['Name'].decode('utf8')
            _id = h['Channel']
            event_channels.append((name, _id, 'event'))
            event_channels.append((name, chan_id, 'event'))
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # fille into header dict
        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [1]
        self.header['signal_streams'] = signal_streams
        self.header['signal_channels'] = sig_channels
        self.header['spike_channels'] = spike_channels
        self.header['event_channels'] = event_channels
        # fill into header dict
        self.header = {
            "nb_block": 1,
            "nb_segment": [1],
            "signal_streams": signal_streams,
            "signal_channels": sig_channels,
            "spike_channels": spike_channels,
            "event_channels": event_channels,
        }

        # Annotations
        self._generate_minimal_annotations()
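
The guarded tqdm import above lets the reader degrade gracefully when the dependency is missing. The same pattern as a standalone sketch, where scan_blocks and BLOCK_LEN are hypothetical names for illustration, not neo API:

    try:
        from tqdm import tqdm
        HAVE_TQDM = True
    except ImportError:
        HAVE_TQDM = False

    BLOCK_LEN = 16  # toy fixed block size; real Plexon blocks are variable-length

    def scan_blocks(data, progress_bar=True):
        # silently disable the bar when tqdm is unavailable
        show = HAVE_TQDM and progress_bar
        bar = tqdm(total=len(data), desc="Parsing data blocks") if show else None
        pos = 0
        while pos < len(data):
            # ... process the block starting at `pos` here ...
            pos += BLOCK_LEN
            if bar is not None:
                bar.update(BLOCK_LEN)
        if bar is not None:
            bar.close()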
