From ea669e20692da4c654b22949c466863b54c845bc Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 4 Oct 2024 11:19:08 +0200 Subject: [PATCH] Fix icoh normalisation term and add unit test (#366) * Update bispectra feature * Fix icoh normalisation term and add unit test * expose nperseg * Update test and coh equations --------- Co-authored-by: timonmerk --- py_neuromodulation/default_settings.yaml | 1 + py_neuromodulation/features/coherence.py | 17 ++-- tests/test_coherence.py | 119 +++++++++++++++++++++++ 3 files changed, 131 insertions(+), 6 deletions(-) create mode 100644 tests/test_coherence.py diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index c72e2fdd..80d79b6c 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -193,6 +193,7 @@ coherence_settings: method: coh: true icoh: true + nperseg: 128 fooof_settings: aperiodic: diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py index 21ca471b..421037d9 100644 --- a/py_neuromodulation/features/coherence.py +++ b/py_neuromodulation/features/coherence.py @@ -26,7 +26,7 @@ class CoherenceFeatures(BoolSelector): mean_fband: bool = True max_fband: bool = True max_allfbands: bool = True - + ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)] @@ -35,6 +35,7 @@ class CoherenceSettings(NMBaseModel): features: CoherenceFeatures = CoherenceFeatures() method: CoherenceMethods = CoherenceMethods() channels: list[ListOfTwoStr] = [] + nperseg: int = Field(default=128, ge=0) frequency_bands: list[str] = Field(default=["high_beta"], min_length=1) @field_validator("frequency_bands") @@ -49,6 +50,7 @@ def __init__( window: str, fbands: list[FrequencyRange], fband_names: list[str], + nperseg: int, ch_1_name: str, ch_2_name: str, ch_1_idx: int, @@ -65,6 +67,7 @@ def __init__( self.ch_2 = ch_2_name self.ch_1_idx = ch_1_idx self.ch_2_idx = ch_2_idx + self.nperseg = nperseg self.coh = coh self.icoh = icoh self.features_coh = features_coh @@ -79,14 +82,15 @@ def __init__( def get_coh(self, feature_results, x, y): from scipy.signal import welch, csd - self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=128) - self.Pyy = welch(y, self.sfreq, self.window, nperseg=128)[1] - self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=128)[1] + self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=self.nperseg) + self.Pyy = welch(y, self.sfreq, self.window, nperseg=self.nperseg)[1] + self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=self.nperseg)[1] if self.coh: - self.coh_val = np.abs(self.Pxy**2) / (self.Pxx * self.Pyy) + # XXX: gives different output to abs(Sxy) / sqrt(Sxx * Syy) + self.coh_val = np.abs(self.Pxy) ** 2 / (self.Pxx * self.Pyy) if self.icoh: - self.icoh_val = np.array(self.Pxy / (self.Pxx * self.Pyy)).imag + self.icoh_val = self.Pxy.imag / np.sqrt(self.Pxx * self.Pyy) for coh_idx, coh_type in enumerate([self.coh, self.icoh]): if coh_type: @@ -180,6 +184,7 @@ def __init__( "hann", fband_specs, fband_names, + self.settings.nperseg, ch_1_name, ch_2_name, ch_1_idx, diff --git a/tests/test_coherence.py b/tests/test_coherence.py new file mode 100644 index 00000000..a3684923 --- /dev/null +++ b/tests/test_coherence.py @@ -0,0 +1,119 @@ +import numpy as np +from mne_connectivity import make_signals_in_freq_bands + +import py_neuromodulation as nm + + +def test_coherence(): + """Check that coherence features compute properly and match expected values.""" + # Simulate connectivity data (interaction at specified frequency band) + sfreq = 500 # Hz + n_epochs = 1 + n_times = sfreq * 2 # samples + fband = (15, 20) # frequency band of interaction, Hz + trans = 2 # transition bandwidth of signal, Hz + delay = 50 # samples + epochs = make_signals_in_freq_bands( + n_seeds=1, + n_targets=1, + freq_band=fband, + n_epochs=n_epochs, + n_times=n_times, + sfreq=sfreq, + trans_bandwidth=trans, + connection_delay=delay, + ch_names=["seed", "target"], + snr=0.7, # change here requires change in `signal/noise_con` vars below + rng_seed=44 + ) + + # Set up py_nm channels info + ch_names = epochs.ch_names + ch_types = epochs.get_channel_types() + channels = nm.utils.set_channels( + ch_names=ch_names, + ch_types=ch_types, + reference="default", + bads=None, + new_names="default", + used_types=tuple(np.unique(ch_types)), + target_keywords=None, + ) + + # Set up pn_nm processing settings + settings = nm.NMSettings.get_default() + settings.reset() + settings.features.coherence = True + + # redefine freq. bands of interest + # (accounts for signal-noise transition bandwdith when defining frequencies) + settings.frequency_ranges_hz = { + "signal": { # strong connectivity expected + "frequency_low_hz": fband[0], + "frequency_high_hz": fband[1], + }, + "noise_low": { # weak connectivity expected from 1 Hz to start of interaction + "frequency_low_hz": 1, + "frequency_high_hz": fband[0] - trans * 2, + }, + "noise_high": { # weak connectivity expected from end of interaction to Nyquist + "frequency_low_hz": fband[1] + trans * 2, + "frequency_high_hz": sfreq // 2 - 1, + }, + } + settings.coherence_settings.frequency_bands = ["signal", "noise_low", "noise_high"] + + # only average within each band required + settings.coherence_settings.features = { + "mean_fband": True, "max_fband": False, "max_allfbands": False + } + + # unique all-to-all connectivity indices, i.e.: ([0], [1]) + # XXX: avoids pydantic ValidationError that lists are too short (length == 1) + settings.coherence_settings.channels = [ch_names] + + # do not normalise features for this test! + # (normalisation changes interpretability of connectivity values, making it harder to + # define 'expected' connectivity values) + settings.postprocessing.feature_normalization = False + + # Set up py_nm stream + stream = nm.Stream( + settings=settings, + channels=channels, + path_grids=None, + verbose=True, + sfreq=epochs.info["sfreq"], + ) + + # Compute connectivity + features = stream.run( + epochs.get_data(copy=False)[0], # extract first (and only) epoch from obj + out_dir="./test_data", + experiment_name="test_coherence", + ) + + # Aggregate results over windows + results = {key: None for key in features.keys()} + results.pop("time") + for key in results.keys(): + # average over windows; take absolute before averaging icoh values + results[key] = np.abs(features[key].values).mean() + + node_name = "seed_to_target" + for con_method in ["coh", "icoh"]: + # Define expected connectivity values for signal and noise frequencies + noise_con = 0.15 + signal_con = 0.25 + + # Assert that frequencies of simulated interaction have strong connectivity + np.testing.assert_array_less( + signal_con, results[f"{con_method}_{node_name}_mean_fband_signal"] + ) + # Assert that frequencies of noise have weak connectivity + np.testing.assert_array_less( + results[f"{con_method}_{node_name}_mean_fband_noise_low"], noise_con + ) + np.testing.assert_array_less( + results[f"{con_method}_{node_name}_mean_fband_noise_high"], noise_con + )