Skip to content

Commit

Permalink
Merge pull request NeuralEnsemble#1590 from NeuralEnsemble/black-form…
Browse files Browse the repository at this point in the history
…atting

Black formatting
  • Loading branch information
zm711 authored Oct 27, 2024
2 parents 96a28af + 35256ce commit 4d079ef
Show file tree
Hide file tree
Showing 18 changed files with 162 additions and 170 deletions.
19 changes: 10 additions & 9 deletions neo/rawio/axonrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def _parse_header(self):
# Get raw data by segment
# self._raw_signals = {}
self._t_starts = {}
self._buffer_descriptions = {0 :{}}
self._stream_buffer_slice = {stream_id : None}
self._buffer_descriptions = {0: {}}
self._stream_buffer_slice = {stream_id: None}
pos = 0
for seg_index in range(nb_segment):
length = episode_array[seg_index]["len"]
Expand All @@ -174,12 +174,12 @@ def _parse_header(self):

self._buffer_descriptions[0][seg_index] = {}
self._buffer_descriptions[0][seg_index][buffer_id] = {
"type" : "raw",
"file_path" : str(self.filename),
"dtype" : str(sig_dtype),
"type": "raw",
"file_path": str(self.filename),
"dtype": str(sig_dtype),
"order": "C",
"file_offset" : head_offset + pos * sig_dtype.itemsize,
"shape" : (int(length // nbchannel), int(nbchannel)),
"file_offset": head_offset + pos * sig_dtype.itemsize,
"shape": (int(length // nbchannel), int(nbchannel)),
}
pos += length

Expand Down Expand Up @@ -239,7 +239,9 @@ def _parse_header(self):
else:
gain, offset = 1.0, 0.0

signal_channels.append((name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))
signal_channels.append(
(name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
)

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

Expand Down Expand Up @@ -313,7 +315,6 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id):
return self._buffer_descriptions[block_index][seg_index][buffer_id]


def _event_count(self, block_index, seg_index, event_channel_index):
return self._raw_ev_timestamps.size

Expand Down
54 changes: 28 additions & 26 deletions neo/rawio/baserawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,6 @@ def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id
raise (NotImplementedError)



class BaseRawWithBufferApiIO(BaseRawIO):
"""
Generic class for reader that support "buffer api".
Expand All @@ -1402,7 +1401,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)
# some hdf5 revert teh buffer
time_axis = buffer_desc.get("time_axis", 0)
return buffer_desc['shape'][time_axis]
return buffer_desc["shape"][time_axis]

def _get_analogsignal_chunk(
self,
Expand All @@ -1413,57 +1412,63 @@ def _get_analogsignal_chunk(
stream_index: int,
channel_indexes: list[int] | None,
):

stream_id = self.header["signal_streams"][stream_index]["id"]
buffer_id = self.header["signal_streams"][stream_index]["buffer_id"]

buffer_slice = self._stream_buffer_slice[stream_id]

buffer_slice = self._stream_buffer_slice[stream_id]

buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)

i_start = i_start or 0
i_stop = i_stop or buffer_desc['shape'][0]
i_stop = i_stop or buffer_desc["shape"][0]

if buffer_desc['type'] == "raw":
if buffer_desc["type"] == "raw":

# open files on demand and keep reference to opened file
if not hasattr(self, '_memmap_analogsignal_buffers'):
# open files on demand and keep reference to opened file
if not hasattr(self, "_memmap_analogsignal_buffers"):
self._memmap_analogsignal_buffers = {}
if block_index not in self._memmap_analogsignal_buffers:
self._memmap_analogsignal_buffers[block_index] = {}
if seg_index not in self._memmap_analogsignal_buffers[block_index]:
self._memmap_analogsignal_buffers[block_index][seg_index] = {}
if buffer_id not in self._memmap_analogsignal_buffers[block_index][seg_index]:
fid = open(buffer_desc['file_path'], mode='rb')
fid = open(buffer_desc["file_path"], mode="rb")
self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] = fid
else:
fid = self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id]

num_channels = buffer_desc['shape'][1]

raw_sigs = get_memmap_chunk_from_opened_file(fid, num_channels, i_start, i_stop, np.dtype(buffer_desc['dtype']), file_offset=buffer_desc['file_offset'])


elif buffer_desc['type'] == 'hdf5':
num_channels = buffer_desc["shape"][1]

raw_sigs = get_memmap_chunk_from_opened_file(
fid,
num_channels,
i_start,
i_stop,
np.dtype(buffer_desc["dtype"]),
file_offset=buffer_desc["file_offset"],
)

elif buffer_desc["type"] == "hdf5":

# open files on demand and keep reference to opened file
if not hasattr(self, '_hdf5_analogsignal_buffers'):
# open files on demand and keep reference to opened file
if not hasattr(self, "_hdf5_analogsignal_buffers"):
self._hdf5_analogsignal_buffers = {}
if block_index not in self._hdf5_analogsignal_buffers:
self._hdf5_analogsignal_buffers[block_index] = {}
if seg_index not in self._hdf5_analogsignal_buffers[block_index]:
self._hdf5_analogsignal_buffers[block_index][seg_index] = {}
if buffer_id not in self._hdf5_analogsignal_buffers[block_index][seg_index]:
import h5py
h5file = h5py.File(buffer_desc['file_path'], mode="r")

h5file = h5py.File(buffer_desc["file_path"], mode="r")
self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id] = h5file
else:
h5file = self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id]

hdf5_path = buffer_desc["hdf5_path"]
full_raw_sigs = h5file[hdf5_path]

time_axis = buffer_desc.get("time_axis", 0)
if time_axis == 0:
raw_sigs = full_raw_sigs[i_start:i_stop, :]
Expand All @@ -1475,31 +1480,28 @@ def _get_analogsignal_chunk(
if buffer_slice is not None:
raw_sigs = raw_sigs[:, buffer_slice]



else:
raise NotImplementedError()

# this is a pre slicing when the stream do not contain all channels (for instance spikeglx when load_sync_channel=False)
if buffer_slice is not None:
raw_sigs = raw_sigs[:, buffer_slice]

# channel slice requested
if channel_indexes is not None:
raw_sigs = raw_sigs[:, channel_indexes]


return raw_sigs

def __del__(self):
if hasattr(self, '_memmap_analogsignal_buffers'):
if hasattr(self, "_memmap_analogsignal_buffers"):
for block_index in self._memmap_analogsignal_buffers.keys():
for seg_index in self._memmap_analogsignal_buffers[block_index].keys():
for buffer_id, fid in self._memmap_analogsignal_buffers[block_index][seg_index].items():
fid.close()
del self._memmap_analogsignal_buffers

if hasattr(self, '_hdf5_analogsignal_buffers'):
if hasattr(self, "_hdf5_analogsignal_buffers"):
for block_index in self._hdf5_analogsignal_buffers.keys():
for seg_index in self._hdf5_analogsignal_buffers[block_index].keys():
for buffer_id, h5_file in self._hdf5_analogsignal_buffers[block_index][seg_index].items():
Expand Down
13 changes: 6 additions & 7 deletions neo/rawio/brainvisionrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,18 @@ def _parse_header(self):

sig_dtype = np.dtype(fmts[fmt])


stream_id = "0"
buffer_id = "0"
self._buffer_descriptions = {0 :{0 : {}}}
self._buffer_descriptions = {0: {0: {}}}
self._stream_buffer_slice = {}
shape = get_memmap_shape(binary_filename, sig_dtype, num_channels=nb_channel, offset=0)
self._buffer_descriptions[0][0][buffer_id] = {
"type" : "raw",
"file_path" : binary_filename,
"dtype" : str(sig_dtype),
"type": "raw",
"file_path": binary_filename,
"dtype": str(sig_dtype),
"order": "C",
"file_offset" : 0,
"shape" : shape,
"file_offset": 0,
"shape": shape,
}
self._stream_buffer_slice[stream_id] = None

Expand Down
11 changes: 7 additions & 4 deletions neo/rawio/edfrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,17 @@ def _parse_header(self):
for array_key in array_keys:
array_anno = {array_key: [h[array_key] for h in self.signal_headers]}
seg_ann["signals"].append({"__array_annotations__": array_anno})

# We store the following attributes for rapid access without needing the reader

self._t_stop = self.edf_reader.datarecord_duration * self.edf_reader.datarecords_in_file
# use sample count of first signal in stream
self._stream_index_samples = {stream_index : self.edf_reader.getNSamples()[chidx][0] for stream_index, chidx in self.stream_idx_to_chidx.items()}
self._stream_index_samples = {
stream_index: self.edf_reader.getNSamples()[chidx][0]
for stream_index, chidx in self.stream_idx_to_chidx.items()
}
self._number_of_events = len(self.edf_reader.readAnnotations()[0])

self.close()

def _get_stream_channels(self, stream_index):
Expand Down
14 changes: 7 additions & 7 deletions neo/rawio/elanrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _parse_header(self):
buffer_id = "0"
signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype)

sig_channels = []
for c, chan_info in enumerate(channel_infos):
chan_name = chan_info["label"]
Expand Down Expand Up @@ -197,16 +197,16 @@ def _parse_header(self):
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

# raw data
self._buffer_descriptions = {0 :{0 : {}}}
self._buffer_descriptions = {0: {0: {}}}
self._stream_buffer_slice = {}
shape = get_memmap_shape(self.filename, sig_dtype, num_channels=nb_channel + 2, offset=0)
self._buffer_descriptions[0][0][buffer_id] = {
"type" : "raw",
"file_path" : self.filename,
"dtype" : sig_dtype,
"type": "raw",
"file_path": self.filename,
"dtype": sig_dtype,
"order": "C",
"file_offset" : 0,
"shape" : shape,
"file_offset": 0,
"shape": shape,
}
self._stream_buffer_slice["0"] = slice(0, -2)

Expand Down
18 changes: 10 additions & 8 deletions neo/rawio/maxwellrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ def _parse_header(self):

# create signal channels
max_sig_length = 0
self._buffer_descriptions = {0 :{0 :{}}}
self._buffer_descriptions = {0: {0: {}}}
self._stream_buffer_slice = {}
sig_channels = []
well_indices_to_remove = []
for stream_index, stream_id in enumerate(signal_streams["id"]):

if int(version) == 20160704:
sr = 20000.0
settings = h5file["settings"]
Expand Down Expand Up @@ -163,14 +163,14 @@ def _parse_header(self):
continue

self._stream_buffer_slice[stream_id] = None

buffer_id = stream_id
shape = h5file[hdf5_path].shape
self._buffer_descriptions[0][0][buffer_id] = {
"type" : "hdf5",
"file_path" : str(self.filename),
"hdf5_path" : hdf5_path,
"shape" : shape,
"type": "hdf5",
"file_path": str(self.filename),
"hdf5_path": hdf5_path,
"shape": shape,
"time_axis": 1,
}
self._stream_buffer_slice[stream_id] = slice(None)
Expand Down Expand Up @@ -232,7 +232,9 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
try:
return super()._get_analogsignal_chunk(block_index, seg_index, i_start, i_stop, stream_index, channel_indexes)
return super()._get_analogsignal_chunk(
block_index, seg_index, i_start, i_stop, stream_index, channel_indexes
)
except OSError as e:
print("*" * 10)
print(_hdf_maxwell_error)
Expand Down
23 changes: 10 additions & 13 deletions neo/rawio/micromedrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ class MicromedRawIO(BaseRawWithBufferApiIO):
extensions = ["trc", "TRC"]
rawmode = "one-file"


def __init__(self, filename=""):
BaseRawWithBufferApiIO.__init__(self)
self.filename = filename

def _parse_header(self):

self._buffer_descriptions = {0 :{ 0: {}}}
self._buffer_descriptions = {0: {0: {}}}

with open(self.filename, "rb") as fid:
f = StructFile(fid)
Expand Down Expand Up @@ -106,16 +105,14 @@ def _parse_header(self):
buffer_id = "0"
stream_id = "0"
self._buffer_descriptions[0][0][buffer_id] = {
"type" : "raw",
"file_path" : str(self.filename),
"dtype" : sig_dtype,
"type": "raw",
"file_path": str(self.filename),
"dtype": sig_dtype,
"order": "C",
"file_offset" : 0,
"shape" : signal_shape,
"file_offset": 0,
"shape": signal_shape,
}



# Reading Code Info
zname2, pos, length = zones["ORDER"]
f.seek(pos)
Expand Down Expand Up @@ -144,12 +141,12 @@ def _parse_header(self):
sampling_rate *= Rate_Min
chan_id = str(c)


signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))

signal_channels.append(
(chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
)

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

self._stream_buffer_slice = {"0": slice(None)}
signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype)
Expand Down
12 changes: 6 additions & 6 deletions neo/rawio/neuronexusrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ def _parse_header(self):

# the will cretae a memory map with teh generic mechanism
buffer_id = "0"
self._buffer_descriptions = {0 :{0 :{}}}
self._buffer_descriptions = {0: {0: {}}}
self._buffer_descriptions[0][0][buffer_id] = {
"type" : "raw",
"file_path" : str(binary_file),
"dtype" : BINARY_DTYPE,
"type": "raw",
"file_path": str(binary_file),
"dtype": BINARY_DTYPE,
"order": "C",
"file_offset" : 0,
"shape" : (self._n_samples, self._n_channels),
"file_offset": 0,
"shape": (self._n_samples, self._n_channels),
}
# Make the memory map for timestamp
self._timestamps = np.memmap(
Expand Down
Loading

0 comments on commit 4d079ef

Please sign in to comment.