Skip to content

Commit

Permalink
Merge pull request #62 from schmidtfa/minor_fixes
Browse files Browse the repository at this point in the history
added aperiodic error estimation to tf irasa
  • Loading branch information
schmidtfa authored Sep 16, 2024
2 parents 617b179 + 84b3c66 commit eeae646
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 28 deletions.
10 changes: 5 additions & 5 deletions examples/irasa_mne.ipynb

Large diffs are not rendered by default.

246 changes: 230 additions & 16 deletions examples/irasa_sprint.ipynb

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,14 @@ def _local_irasa_fun(
hset=hset,
time=time,
)
single_ch_dim = 2
sgramm_aperiodic = (
sgramm_aperiodic[np.newaxis, :, :] if sgramm_aperiodic.ndim == single_ch_dim else sgramm_aperiodic
)
sgramm_periodic = sgramm_periodic[np.newaxis, :, :] if sgramm_periodic.ndim == single_ch_dim else sgramm_periodic

freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data(
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=1
)

# adjust time info (i.e. cut the padded stuff)
Expand All @@ -332,7 +337,7 @@ def _local_irasa_fun(
freqs=freq[freq_mask],
time=time[t_mask],
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask][freq_mask, :],
aperiodic=sgramm_aperiodic[:, t_mask][freq_mask, :],
periodic=sgramm_periodic[:, :, t_mask][:, freq_mask, :],
aperiodic=sgramm_aperiodic[:, :, t_mask][:, freq_mask, :],
ch_names=ch_names,
)
2 changes: 1 addition & 1 deletion pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def compute_aperiodic_model(

# generate channel names if not given
if ch_names is None:
ch_names = np.arange(aperiodic_spectrum.shape[0])
ch_names = [str(i) for i in np.arange(aperiodic_spectrum.shape[0])]

if scale:

Expand Down
2 changes: 1 addition & 1 deletion pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd

# https://robjhyndman.com/hyndsight/lm_aic.html
# c is in practice sometimes dropped. Only relevant when comparing models with different n
# c = np.log(n) + np.log(n) * np.log(2 * np.pi)
# c = n + n * np.log(2 * np.pi)
# aic = 2 * k + n * np.log(mse) + c #real
aic = 2 * k + np.log(n) * np.log(mse) # + c
# aic = 2 * k + n * mse
Expand Down
67 changes: 67 additions & 0 deletions pyrasa/utils/irasa_tf_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,70 @@ def get_peaks(
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)

def get_aperiodic_error(self, peak_kwargs: None | dict = None) -> np.ndarray:
"""
Computes the frequency resolved error of the aperiodic spectrum.
This method first computes the absolute of the periodic spectrum and subsequently zeroes out
any peaks in the spectrum that are potentially "oscillations", yielding the residual error of the aperiodic
spectrum as a function of frequency.
This can be useful when trying to optimize hyperparameters such as the hset.
peak_kwargs : dict
A dictionary containing keyword arguments that are passed on to the peak finding method 'get_peaks'
Returns
-------
np.ndarray
A numpy array containing the frequency resolved squared error of the aperiodic
spectrum extracted using irasa
Notes
-----
While not strictly necessary, setting peak_kwargs is highly recommended.
The reason for this is that through up-/downsampling and averaging "broadband"
parameters such as spectral knees can bleed in the periodic spectrum and could be wrongfully
interpreted as oscillations. This can be avoided by e.g. explicitely setting `min_peak_height`.
A good way of making a decision for the periodic parameters is to base it on the settings
used in peak detection.
"""

if peak_kwargs is None:
peak_kwargs = {}

# get absolute periodic spectrum & zero-out peaks
freqs = self.freqs
peaks = self.get_peaks(**peak_kwargs)
peak_times = peaks['time'].unique()
ch_names = peaks['ch_name'].unique()

valid_peak_times = [cur_t in peak_times for cur_t in self.time]
aperiodic_error = np.abs(self.periodic)
aperiodic_error_cut = aperiodic_error[:, :, valid_peak_times]

aperiodic_errors_ch = []
for c_ix, ch in enumerate(ch_names):
cur_ch_ape = aperiodic_error_cut[c_ix, :, :]
cur_peak_ch = peaks.query(f'ch_name == "{ch}"')

aperiodic_errors_t = []
for t_ix, cur_t in enumerate(peak_times):
cur_t_ape = cur_ch_ape[:, t_ix]
cur_t_ch = cur_peak_ch.query(f'time == "{cur_t}"')

for _, peak in cur_t_ch.iterrows():
cur_upper = peak['cf'] + peak['bw']
cur_lower = peak['cf'] - peak['bw']

freq_mask = np.logical_and(freqs < cur_upper, freqs > cur_lower)

cur_t_ape[freq_mask] = 0

aperiodic_errors_t.append(cur_t_ape)

aperiodic_errors_ch.append(np.array(aperiodic_errors_t).T)

return np.array(aperiodic_errors_ch)
2 changes: 1 addition & 1 deletion pyrasa/utils/peak_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_peak_params(

# generate channel names if not given
if ch_names is None:
ch_names = np.arange(periodic_spectrum.shape[0])
ch_names = [str(i) for i in np.arange(periodic_spectrum.shape[0])]

# cut data
if cut_spectrum is not None:
Expand Down
74 changes: 74 additions & 0 deletions simulations/notebooks/aperiodic_error_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#%%
import sys
from neurodsp.sim import set_random_seed
from neurodsp.sim import sim_powerlaw, sim_oscillation
from neurodsp.utils import create_times
from neurodsp.plts import plot_timefrequency#

from neurodsp.timefrequency import compute_wavelet_transform
import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
import pandas as pd

import matplotlib as mpl
new_rc_params = {'text.usetex': False,
"svg.fonttype": 'none'
}
mpl.rcParams.update(new_rc_params)

set_random_seed(84)

from pyrasa.irasa import irasa_sprint
# %%
# Set some general settings, to be used across all simulations
fs = 500
n_seconds = 15
duration=4
overlap=0.5

# Create a times vector for the simulations
times = create_times(n_seconds, fs)


alpha = sim_oscillation(n_seconds=.5, fs=fs, freq=10)
no_alpha = np.zeros(len(alpha))
beta = sim_oscillation(n_seconds=.5, fs=fs, freq=25)
no_beta = np.zeros(len(beta))

exp_1 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-1)
exp_2 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-2)


alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha])
betas = np.concatenate([beta, no_beta, beta, no_beta, beta])

sim_ts = np.concatenate([exp_1 + alphas,
exp_1 + alphas + betas,
exp_1 + betas,
exp_2 + alphas,
exp_2 + alphas + betas,
exp_2 + betas, ])

# %%
freqs = np.arange(1, 50, 0.5)
import scipy.signal as dsp

irasa_sprint_spectrum = irasa_sprint(sim_ts,#np.array([sim_ts, sim_ts]),
fs=fs,
band=(1, 50),
overlap_fraction=.95,
win_duration=.5,
ch_names=['A'],
hset_info=(1.05, 4., 0.05),
win_func=dsp.windows.hann)
# %%
peak_kwargs = { 'smooth': True,
'smoothing_window':1,
'peak_threshold':5,
'min_peak_height':.01,
'peak_width_limits': (0.5, 12)}
ap_error = irasa_sprint_spectrum.get_aperiodic_error(peak_kwargs)
# %%
plt.plot(ap_error[0,:,:].mean(axis=1))
# %%
29 changes: 29 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,35 @@ def ts4sprint(fs, exponent_1, exponent_2):
yield sim_ts


@pytest.fixture(scope='session')
def ts4sprint_knee(fs, exponent_1, exponent_2):
alpha = sim_oscillation(n_seconds=0.5, fs=fs, freq=10)
no_alpha = np.zeros(len(alpha))
beta = sim_oscillation(n_seconds=0.5, fs=fs, freq=25)
no_beta = np.zeros(len(beta))

knee1 = 20 ** np.abs(exponent_1)
knee2 = 20 ** np.abs(exponent_2)
exp_1 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_1, knee=knee1)
exp_2 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_2, knee=knee2)

# %%
alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha])
betas = np.concatenate([beta, no_beta, beta, no_beta, beta])

sim_ts = np.concatenate(
[
exp_1 + alphas,
exp_1 + alphas + betas,
exp_1 + betas,
exp_2 + alphas,
exp_2 + alphas + betas,
exp_2 + betas,
]
)
yield sim_ts


@pytest.fixture(scope='session')
def gen_mne_data_raw():
data_path = sample.data_path()
Expand Down
37 changes: 36 additions & 1 deletion tests/test_irasa_knee.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import scipy.signal as dsp

from pyrasa import irasa
from pyrasa import irasa, irasa_sprint

from .settings import EXP_KNEE_COMBO, FS, KNEE_TOLERANCE, MIN_CORR_PSD_CMB, OSC_FREQ, TOLERANCE

Expand Down Expand Up @@ -109,3 +109,38 @@ def test_aperiodic_error(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
)

assert np.mean(irasa_out.get_aperiodic_error()) < np.mean(irasa_out_bad.get_aperiodic_error())


@pytest.mark.parametrize('fs', [1000], scope='session')
@pytest.mark.parametrize('exponent_1', [-0], scope='session')
@pytest.mark.parametrize('exponent_2', [-2], scope='session')
def test_aperiodic_error_tf(ts4sprint_knee, fs, exponent_1, exponent_2):
irasa_out = irasa_sprint(
ts4sprint_knee,
fs=fs,
band=(0.1, 50),
overlap_fraction=0.95,
win_duration=0.5,
hset_info=(1, 2.0, 0.05),
)

irasa_out_bad = irasa_sprint(
ts4sprint_knee,
fs=fs,
band=(0.1, 50),
overlap_fraction=0.95,
win_duration=0.5,
hset_info=(1, 8.0, 0.05),
)

kwargs = {
'cut_spectrum': (1, 40),
'smooth': True,
'smoothing_window': 3,
'min_peak_height': 0.01,
'peak_width_limits': (0.5, 12),
}

assert np.mean(irasa_out.get_aperiodic_error(peak_kwargs=kwargs)) < np.mean(
irasa_out_bad.get_aperiodic_error(peak_kwargs=kwargs)
)

0 comments on commit eeae646

Please sign in to comment.