Skip to content

Commit

Permalink
Merge branch 'master' into numpy-2-0
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 authored Dec 6, 2024
2 parents 7449082 + 4061fdf commit b078766
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 21 deletions.
73 changes: 53 additions & 20 deletions neo/rawio/openephysbinaryrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def _parse_header(self):
if name + "_npy" in info:
data = np.load(info[name + "_npy"], mmap_mode="r")
info[name] = data

# check that events have timestamps
assert "timestamps" in info, "Event stream does not have timestamps!"
# Updates for OpenEphys v0.6:
Expand Down Expand Up @@ -253,30 +252,64 @@ def _parse_header(self):
# 'states' was introduced in OpenEphys v0.6. For previous versions, events used 'channel_states'
if "states" in info or "channel_states" in info:
states = info["channel_states"] if "channel_states" in info else info["states"]

if states.size > 0:
timestamps = info["timestamps"]
labels = info["labels"]
rising = np.where(states > 0)[0]
falling = np.where(states < 0)[0]

# infer durations
# Identify unique channels based on state values
channels = np.unique(np.abs(states))

rising_indices = []
falling_indices = []

# all channels are packed into the same `states` array.
# So the states array includes positive and negative values for each channel:
# for example channel one rising would be +1 and channel one falling would be -1,
# channel two rising would be +2 and channel two falling would be -2, etc.
# This is the case for sure for version >= 0.6.x.
for channel in channels:
# Find rising and falling edges for each channel
rising = np.where(states == channel)[0]
falling = np.where(states == -channel)[0]

# Ensure each rising has a corresponding falling
if rising.size > 0 and falling.size > 0:
if rising[0] > falling[0]:
falling = falling[1:]
if rising.size > falling.size:
rising = rising[:-1]

# ensure that the number of rising and falling edges are the same:
if len(rising) != len(falling):
warn(
f"Channel {channel} has {len(rising)} rising edges and "
f"{len(falling)} falling edges. The number of rising and "
f"falling edges should be equal. Skipping events from this channel."
)
continue

rising_indices.extend(rising)
falling_indices.extend(falling)

rising_indices = np.array(rising_indices)
falling_indices = np.array(falling_indices)

# Sort the indices to maintain chronological order
sorted_order = np.argsort(rising_indices)
rising_indices = rising_indices[sorted_order]
falling_indices = falling_indices[sorted_order]

durations = None
if len(states) > 0:
# make sure first event is rising and last is falling
if states[0] < 0:
falling = falling[1:]
if states[-1] > 0:
rising = rising[:-1]

if len(rising) == len(falling):
durations = timestamps[falling] - timestamps[rising]
if not self._use_direct_evt_timestamps:
timestamps = timestamps / info["sample_rate"]
durations = durations / info["sample_rate"]

info["rising"] = rising
info["timestamps"] = timestamps[rising]
info["labels"] = labels[rising]
# if len(rising_indices) == len(falling_indices):
durations = timestamps[falling_indices] - timestamps[rising_indices]
if not self._use_direct_evt_timestamps:
timestamps = timestamps / info["sample_rate"]
durations = durations / info["sample_rate"]

info["rising"] = rising_indices
info["timestamps"] = timestamps[rising_indices]
info["labels"] = labels[rising_indices]
info["durations"] = durations

# no spike read yet
Expand Down
21 changes: 21 additions & 0 deletions neo/test/rawiotest/test_openephysbinaryrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from neo.rawio.openephysbinaryrawio import OpenEphysBinaryRawIO
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO

import numpy as np


class TestOpenEphysBinaryRawIO(BaseTestRawIO, unittest.TestCase):
rawioclass = OpenEphysBinaryRawIO
Expand Down Expand Up @@ -57,6 +59,25 @@ def test_missing_folders(self):
)
rawio.parse_header()

def test_multiple_ttl_events_parsing(self):
rawio = OpenEphysBinaryRawIO(
self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"), load_sync_channel=False
)
rawio.parse_header()
rawio.header = rawio.header
# Testing co
# This is the TTL events from the NI Board channel
ttl_events = rawio._evt_streams[0][0][1]
assert "rising" in ttl_events.keys()
assert "labels" in ttl_events.keys()
assert "durations" in ttl_events.keys()
assert "timestamps" in ttl_events.keys()

# Check that durations of different event streams are correctly parsed:
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "1"], 0.5, atol=0.001)
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "6"], 0.025, atol=0.001)
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "7"], 0.016666, atol=0.001)


if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ igorproio = ["igor2"]
kwikio = ["klusta"]
neomatlabio = ["scipy>=1.0.0"]
nixio = ["nixio>=1.5.0"]
stimfitio = ["stfio"]
tiffio = ["pillow"]
edf = ["pyedflib"]
ced = ["sonpy"]
Expand Down

0 comments on commit b078766

Please sign in to comment.