Skip to content

Commit

Permalink
Merge pull request #60 from schmidtfa/minor_fixes
Browse files Browse the repository at this point in the history
Add a method to estimate the error of the aperiodic model
  • Loading branch information
schmidtfa authored Aug 26, 2024
2 parents a4b8a49 + a2d85a5 commit cc3fc86
Show file tree
Hide file tree
Showing 4 changed files with 673 additions and 5 deletions.
592 changes: 592 additions & 0 deletions examples/hset_optimization.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,19 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd

# https://robjhyndman.com/hyndsight/lm_aic.html
# c is in practice sometimes dropped. Only relevant when comparing models with different n
# c = n + n * np.log(2 * np.pi)
# c = np.log(n) + np.log(n) * np.log(2 * np.pi)
# aic = 2 * k + n * np.log(mse) + c #real
aic = 2 * k + np.log(n) * np.log(mse)
aic = 2 * k + np.log(n) * np.log(mse) # + c
# aic = 2 * k + n * mse
# according to Sclove 1987 only difference between BIC and AIC
# is that BIC uses log(n) * k instead of 2 * k
# bic = np.log(n) * k + n * np.log(mse) + c #real
bic = np.log(n) * k + np.log(n) * np.log(mse)
bic = np.log(n) * k + np.log(n) * np.log(mse) # + c
# bic = np.log(n) * k + n * mse
# Sclove 1987 also hints at sample size adjusted bic
an = np.log((n + 2) / 24) # defined in Rissanen 1978 based on minimum-bit representation of a signal
an = (n + 2) / 24 # defined in Rissanen 1978 based on minimum-bit representation of a signal
# abic -> https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7299313/
abic = np.log(an) * k + an * np.log(mse)
abic = np.log(an) * k + np.log(n) * np.log(mse)

r2 = 1 - (ss_res / ss_tot)
r2_adj = 1 - (((1 - r2) * (n - 1)) / (n - k - 1))
Expand Down
49 changes: 49 additions & 0 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,52 @@ def get_peaks(
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)

def get_aperiodic_error(self, peak_kwargs: None | dict = None) -> np.ndarray:
"""
Computes the frequency resolved error of the aperiodic spectrum.
This method first computes the absolute of the periodic spectrum and subsequently zeroes out
any peaks in the spectrum that are potentially "oscillations", yielding the residual error of the aperiodic
spectrum as a function of frequency.
This can be useful when trying to optimize hyperparameters such as the hset.
peak_kwargs : dict
A dictionary containing keyword arguments that are passed on to the peak finding method 'get_peaks'
Returns
-------
np.ndarray
A numpy array containing the frequency resolved squared error of the aperiodic
spectrum extracted using irasa
Notes
-----
While not strictly necessary, setting peak_kwargs is highly recommended.
The reason for this is that through up-/downsampling and averaging "broadband"
parameters such as spectral knees can bleed in the periodic spectrum and could be wrongfully
interpreted as oscillations. This can be avoided by e.g. explicitely setting `min_peak_height`.
A good way of making a decision for the periodic parameters is to base it on the settings
used in peak detection.
"""

if peak_kwargs is None:
peak_kwargs = {}
# get absolute periodic spectrum
aperiodic_error = np.abs(self.periodic[0, :])

# zero-out peaks
peaks = self.get_peaks(**peak_kwargs)
freqs = self.freqs

for _, peak in peaks.iterrows():
cur_upper = peak['cf'] + peak['bw']
cur_lower = peak['cf'] - peak['bw']

freq_mask = np.logical_and(freqs < cur_upper, freqs > cur_lower)

aperiodic_error[freq_mask] = 0

return aperiodic_error
25 changes: 25 additions & 0 deletions tests/test_irasa_knee.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,28 @@ def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
# test whether we can reconstruct the peak frequency correctly
pe_params = irasa_out.get_peaks()
assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq))


@pytest.mark.parametrize('exponent, knee', [(-1.5, 1000)], scope='session')
@pytest.mark.parametrize('fs', [1000], scope='session')
@pytest.mark.parametrize('osc_freq', [10], scope='session')
def test_aperiodic_error(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
    """Check that a moderate hset gives a smaller mean aperiodic error than a wide one.

    The same signal is decomposed twice with identical PSD settings; only the
    upper bound of ``hset_info`` differs (2.0 vs. 8.0). The mean of
    ``get_aperiodic_error()`` for the 2.0 run must be strictly smaller.
    """
    window_sec = 4
    overlap_fraction = 0.5

    def _decompose(hset_info):
        # build a fresh psd_kwargs dict for every call so the two runs share no state
        return irasa(
            load_knee_cmb_signal,
            fs=fs,
            band=(0.1, 50),
            psd_kwargs={
                'nperseg': window_sec * fs,
                'noverlap': window_sec * fs * overlap_fraction,
            },
            hset_info=hset_info,
        )

    narrow_hset_result = _decompose((1, 2.0, 0.05))
    wide_hset_result = _decompose((1, 8.0, 0.05))

    narrow_error = np.mean(narrow_hset_result.get_aperiodic_error())
    wide_error = np.mean(wide_hset_result.get_aperiodic_error())

    assert narrow_error < wide_error

0 comments on commit cc3fc86

Please sign in to comment.