Skip to content

Commit

Permalink
Fix icoh normalisation term and add unit test (#366)
Browse files Browse the repository at this point in the history
* Update bispectra feature

* Fix icoh normalisation term and add unit test

* expose nperseg

* Update test and coh equations

---------

Co-authored-by: timonmerk <[email protected]>
  • Loading branch information
tsbinns and timonmerk authored Oct 4, 2024
1 parent 6fc5f95 commit ea669e2
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 6 deletions.
1 change: 1 addition & 0 deletions py_neuromodulation/default_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ coherence_settings:
method:
coh: true
icoh: true
nperseg: 128

fooof_settings:
aperiodic:
Expand Down
17 changes: 11 additions & 6 deletions py_neuromodulation/features/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CoherenceFeatures(BoolSelector):
mean_fband: bool = True
max_fband: bool = True
max_allfbands: bool = True


ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]

Expand All @@ -35,6 +35,7 @@ class CoherenceSettings(NMBaseModel):
features: CoherenceFeatures = CoherenceFeatures()
method: CoherenceMethods = CoherenceMethods()
channels: list[ListOfTwoStr] = []
nperseg: int = Field(default=128, ge=0)
frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)

@field_validator("frequency_bands")
Expand All @@ -49,6 +50,7 @@ def __init__(
window: str,
fbands: list[FrequencyRange],
fband_names: list[str],
nperseg: int,
ch_1_name: str,
ch_2_name: str,
ch_1_idx: int,
Expand All @@ -65,6 +67,7 @@ def __init__(
self.ch_2 = ch_2_name
self.ch_1_idx = ch_1_idx
self.ch_2_idx = ch_2_idx
self.nperseg = nperseg
self.coh = coh
self.icoh = icoh
self.features_coh = features_coh
Expand All @@ -79,14 +82,15 @@ def __init__(
def get_coh(self, feature_results, x, y):
from scipy.signal import welch, csd

self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=128)
self.Pyy = welch(y, self.sfreq, self.window, nperseg=128)[1]
self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=128)[1]
self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=self.nperseg)
self.Pyy = welch(y, self.sfreq, self.window, nperseg=self.nperseg)[1]
self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=self.nperseg)[1]

if self.coh:
self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy)
# XXX: gives different output to abs(Sxy) / sqrt(Sxx * Syy)
self.coh_val = np.abs(self.Pxy) ** 2 / (self.Pxx * self.Pyy)
if self.icoh:
self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag
self.icoh_val = self.Pxy.imag / np.sqrt(self.Pxx * self.Pyy)

for coh_idx, coh_type in enumerate([self.coh, self.icoh]):
if coh_type:
Expand Down Expand Up @@ -180,6 +184,7 @@ def __init__(
"hann",
fband_specs,
fband_names,
self.settings.nperseg,
ch_1_name,
ch_2_name,
ch_1_idx,
Expand Down
119 changes: 119 additions & 0 deletions tests/test_coherence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import numpy as np
from mne_connectivity import make_signals_in_freq_bands

import py_neuromodulation as nm


def test_coherence():
"""Check that coherence features compute properly and match expected values."""
# Simulate connectivity data (interaction at specified frequency band)
sfreq = 500 # Hz
n_epochs = 1
n_times = sfreq * 2 # samples
fband = (15, 20) # frequency band of interaction, Hz
trans = 2 # transition bandwidth of signal, Hz
delay = 50 # samples
epochs = make_signals_in_freq_bands(
n_seeds=1,
n_targets=1,
freq_band=fband,
n_epochs=n_epochs,
n_times=n_times,
sfreq=sfreq,
trans_bandwidth=trans,
connection_delay=delay,
ch_names=["seed", "target"],
snr=0.7, # change here requires change in `signal/noise_con` vars below
rng_seed=44
)

# Set up py_nm channels info
ch_names = epochs.ch_names
ch_types = epochs.get_channel_types()
channels = nm.utils.set_channels(
ch_names=ch_names,
ch_types=ch_types,
reference="default",
bads=None,
new_names="default",
used_types=tuple(np.unique(ch_types)),
target_keywords=None,
)

# Set up pn_nm processing settings
settings = nm.NMSettings.get_default()
settings.reset()
settings.features.coherence = True

# redefine freq. bands of interest
# (accounts for signal-noise transition bandwdith when defining frequencies)
settings.frequency_ranges_hz = {
"signal": { # strong connectivity expected
"frequency_low_hz": fband[0],
"frequency_high_hz": fband[1],
},
"noise_low": { # weak connectivity expected from 1 Hz to start of interaction
"frequency_low_hz": 1,
"frequency_high_hz": fband[0] - trans * 2,
},
"noise_high": { # weak connectivity expected from end of interaction to Nyquist
"frequency_low_hz": fband[1] + trans * 2,
"frequency_high_hz": sfreq // 2 - 1,
},
}
settings.coherence_settings.frequency_bands = ["signal", "noise_low", "noise_high"]

# only average within each band required
settings.coherence_settings.features = {
"mean_fband": True, "max_fband": False, "max_allfbands": False
}

# unique all-to-all connectivity indices, i.e.: ([0], [1])
# XXX: avoids pydantic ValidationError that lists are too short (length == 1)
settings.coherence_settings.channels = [ch_names]

# do not normalise features for this test!
# (normalisation changes interpretability of connectivity values, making it harder to
# define 'expected' connectivity values)
settings.postprocessing.feature_normalization = False

# Set up py_nm stream
stream = nm.Stream(
settings=settings,
channels=channels,
path_grids=None,
verbose=True,
sfreq=epochs.info["sfreq"],
)

# Compute connectivity
features = stream.run(
epochs.get_data(copy=False)[0], # extract first (and only) epoch from obj
out_dir="./test_data",
experiment_name="test_coherence",
)

# Aggregate results over windows
results = {key: None for key in features.keys()}
results.pop("time")
for key in results.keys():
# average over windows; take absolute before averaging icoh values
results[key] = np.abs(features[key].values).mean()

node_name = "seed_to_target"
for con_method in ["coh", "icoh"]:
# Define expected connectivity values for signal and noise frequencies
noise_con = 0.15
signal_con = 0.25

# Assert that frequencies of simulated interaction have strong connectivity
np.testing.assert_array_less(
signal_con, results[f"{con_method}_{node_name}_mean_fband_signal"]
)
# Assert that frequencies of noise have weak connectivity
np.testing.assert_array_less(
results[f"{con_method}_{node_name}_mean_fband_noise_low"], noise_con
)
np.testing.assert_array_less(
results[f"{con_method}_{node_name}_mean_fband_noise_high"], noise_con
)

0 comments on commit ea669e2

Please sign in to comment.