
Fix overflow of Plexon in numpy 2.0 #1613

Merged
merged 9 commits on Dec 19, 2024
39 changes: 24 additions & 15 deletions neo/rawio/plexonrawio.py
@@ -209,6 +209,10 @@ def _parse_header(self):
for index, pos in enumerate(positions):
bl_header = data[pos : pos + 16].view(DataBlockHeader)[0]

+ # To avoid overflow errors when doing arithmetic operations on numpy scalars
+ np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
Contributor:

Is there a reason not to define this lambda near the top of the file so it is reusable, or even make it a true private function? As written, the same lambda is re-created on every iteration of the for loop. That should be relatively low cost, but it seems a bit wasteful; we could define the function once and reuse it in both for loops, no?
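
For reference, the reviewer's suggested alternative would look something like the sketch below; `_np_scalar_to_python_scalar` is an illustrative name, not code from the PR.

```python
import numpy as np

# Sketch of the suggestion: define the conversion once as a private
# module-level helper and reuse it in both for loops.
def _np_scalar_to_python_scalar(x):
    """Convert a numpy scalar to its Python equivalent; pass anything else through."""
    return x.item() if isinstance(x, np.generic) else x
```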

Contributor Author:

I'm keeping the definition close to its use so people can quickly see what it is for. Defining the lambda takes nanoseconds:

[screenshot of a timing measurement]

Extracting the data from Plexon takes minutes, so the lambda's cost does not factor in.
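
A rough, hypothetical re-creation of the kind of measurement shown in the omitted screenshot, using timeit:

```python
import timeit

# Define the lambda a million times; each definition costs on the order of
# tens of nanoseconds, negligible next to minutes of Plexon data extraction.
total = timeit.timeit(
    "f = lambda x: x.item() if isinstance(x, np.generic) else x",
    setup="import numpy as np",
    number=1_000_000,
)
print(f"{total / 1_000_000 * 1e9:.0f} ns per definition")
```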

Contributor:

I figured you'd have checked. Fair enough. :)

+ bl_header = {key: np_scalar_to_python_scalar(bl_header[key]) for key in bl_header.dtype.names}

current_upper_byte_of_5_byte_timestamp = int(bl_header["UpperByteOf5ByteTimestamp"])
current_bl_timestamp = int(bl_header["TimeStamp"])
timestamp = current_upper_byte_of_5_byte_timestamp * 2**32 + current_bl_timestamp
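
For context, here is a minimal, hypothetical reproduction of the failure mode this hunk works around; the dtypes are stand-ins, not the actual Plexon header field types. Under NumPy 2.0's NEP 50 promotion rules, combining a fixed-width numpy scalar with a Python int that is out of range for its dtype raises an OverflowError, which is why the header values are converted to Python ints first.

```python
import numpy as np

# Assumed dtypes for illustration; the real header fields may differ.
upper_byte = np.uint8(3)   # stands in for bl_header["UpperByteOf5ByteTimestamp"]
low_ts = np.uint32(7)      # stands in for bl_header["TimeStamp"]

try:
    timestamp = upper_byte * 2**32 + low_ts
except OverflowError as exc:
    print(exc)  # Python integer 4294967296 out of bounds for uint8

# Converting to Python ints first gives exact, unbounded integer arithmetic:
timestamp = upper_byte.item() * 2**32 + low_ts.item()
print(timestamp)  # 12884901895
```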
@@ -255,24 +259,29 @@ def _parse_header(self):
else:
chan_loop = range(nb_sig_chan)
for chan_index in chan_loop:
- h = slowChannelHeaders[chan_index]
- name = h["Name"].decode("utf8")
- chan_id = h["Channel"]
+ channel_headers = slowChannelHeaders[chan_index]
+
+ # To avoid overflow errors when doing arithmetic operations on numpy scalars
+ np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
+ channel_headers = {key: np_scalar_to_python_scalar(channel_headers[key]) for key in channel_headers.dtype.names}
+
+ name = channel_headers["Name"].decode("utf8")
+ chan_id = channel_headers["Channel"]
Contributor:

Last question: below you called it dsp_channel_headers, which was really nice. Any interest in changing this one to slow_channel_headers, just so we never accidentally overwrite a variable? I think the explicit name is cleaner, and since we are already fixing up the variable naming it might be worth the push.

Contributor Author:

Yeah, sure.

length = self._data_blocks[5][chan_id]["size"].sum() // 2
if length == 0:
continue # channel not added
- source_id.append(h["SrcId"])
+ source_id.append(channel_headers["SrcId"])
channel_num_samples.append(length)
- sampling_rate = float(h["ADFreq"])
+ sampling_rate = float(channel_headers["ADFreq"])
sig_dtype = "int16"
units = "" # I don't know units
if global_header["Version"] in [100, 101]:
- gain = 5000.0 / (2048 * h["Gain"] * 1000.0)
+ gain = 5000.0 / (2048 * channel_headers["Gain"] * 1000.0)
elif global_header["Version"] in [102]:
- gain = 5000.0 / (2048 * h["Gain"] * h["PreampGain"])
+ gain = 5000.0 / (2048 * channel_headers["Gain"] * channel_headers["PreampGain"])
elif global_header["Version"] >= 103:
gain = global_header["SlowMaxMagnitudeMV"] / (
- 0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
+ 0.5 * (2 ** global_header["BitsPerSpikeSample"]) * channel_headers["Gain"] * channel_headers["PreampGain"]
)
offset = 0.0
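
The record-to-dict conversion used in this hunk can be illustrated with a small self-contained sketch; the dtype below is invented for illustration and is not the real Plexon slow-channel header layout.

```python
import numpy as np

# Indexing a structured array yields a record of numpy scalars; the dict
# comprehension turns them into overflow-safe Python scalars.
dtype = np.dtype([("Channel", "u2"), ("Gain", "i2"), ("ADFreq", "i4")])
slow_channel_headers = np.array([(1, 500, 1000)], dtype=dtype)

record = slow_channel_headers[0]
to_python = lambda x: x.item() if isinstance(x, np.generic) else x
record = {key: to_python(record[key]) for key in record.dtype.names}

assert all(not isinstance(v, np.generic) for v in record.values())
print(record)  # {'Channel': 1, 'Gain': 500, 'ADFreq': 1000}
```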

@@ -358,21 +367,21 @@ def _parse_header(self):
unit_loop = enumerate(self.internal_unit_ids)

for unit_index, (chan_id, unit_id) in unit_loop:
- c = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
- h = dspChannelHeaders[c]
+ channel_index = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
+ dsp_channel_headers = dspChannelHeaders[channel_index]

- name = h["Name"].decode("utf8")
+ name = dsp_channel_headers["Name"].decode("utf8")
_id = f"ch{chan_id}#{unit_id}"
wf_units = ""
if global_header["Version"] < 103:
- wf_gain = 3000.0 / (2048 * h["Gain"] * 1000.0)
+ wf_gain = 3000.0 / (2048 * dsp_channel_headers["Gain"] * 1000.0)
elif 103 <= global_header["Version"] < 105:
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
- 0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * 1000.0
+ 0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * 1000.0
)
elif global_header["Version"] >= 105:
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
- 0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * global_header["SpikePreAmpGain"]
+ 0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * global_header["SpikePreAmpGain"]
)
wf_offset = 0.0
wf_left_sweep = -1 # DONT KNOWN
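
As a concrete illustration of why these gain formulas needed Python scalars: with a fixed-width numpy scalar, the intermediate integer product can silently wrap, while a Python int is computed exactly. The value below is a hypothetical stand-in for a header gain, not data from a real file.

```python
import numpy as np

gain = np.int16(500)  # hypothetical stand-in for channel_headers["Gain"]

# numpy scalar: the intermediate 2048 * 500 = 1_024_000 wraps in int16
# (numpy emits a RuntimeWarning), giving a wrong, negative denominator.
print(2048 * gain * 1000.0)         # -24576000.0

# Python int: exact, unbounded arithmetic.
print(2048 * gain.item() * 1000.0)  # 1024000000.0
```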
@@ -576,7 +585,7 @@ def read_as_dict(fid, dtype, offset=None):
v = v.replace("\x03", "")
v = v.replace("\x00", "")

- info[k] = v
+ info[k] = v.item() if isinstance(v, np.generic) else v
return info
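
A hypothetical, self-contained sketch of the pattern this hunk completes; it is not neo's actual read_as_dict, and the dtype is invented for illustration.

```python
import io
import numpy as np

# Read one structured record from a binary stream and return plain Python
# values, converting numpy scalars with .item() as the hunk above does.
dtype = np.dtype([("Version", "i4"), ("Comment", "S8")])
fid = io.BytesIO(np.array([(105, b"hi")], dtype=dtype).tobytes())

record = np.frombuffer(fid.read(dtype.itemsize), dtype=dtype)[0]
info = {}
for k in dtype.names:
    v = record[k]
    if dtype[k].kind == "S":
        v = v.decode("utf8").replace("\x00", "")
    info[k] = v.item() if isinstance(v, np.generic) else v

print(info)  # {'Version': 105, 'Comment': 'hi'}
```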

