Skip to content

Commit

Permalink
added function for aperiodic error
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 25, 2024
1 parent 7e67975 commit f0d29a7
Show file tree
Hide file tree
Showing 3 changed files with 648 additions and 5 deletions.
592 changes: 592 additions & 0 deletions examples/hset_optimization.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,19 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd

# https://robjhyndman.com/hyndsight/lm_aic.html
# c is in practice sometimes dropped. Only relevant when comparing models with different n
# c = n + n * np.log(2 * np.pi)
# c = np.log(n) + np.log(n) * np.log(2 * np.pi)
# aic = 2 * k + n * np.log(mse) + c #real
aic = 2 * k + np.log(n) * np.log(mse)
aic = 2 * k + np.log(n) * np.log(mse) # + c
# aic = 2 * k + n * mse
# according to Sclove 1987 only difference between BIC and AIC
# is that BIC uses log(n) * k instead of 2 * k
# bic = np.log(n) * k + n * np.log(mse) + c #real
bic = np.log(n) * k + np.log(n) * np.log(mse)
bic = np.log(n) * k + np.log(n) * np.log(mse) # + c
# bic = np.log(n) * k + n * mse
# Sclove 1987 also hints at sample size adjusted bic
an = np.log((n + 2) / 24) # defined in Rissanen 1978 based on minimum-bit representation of a signal
an = (n + 2) / 24 # defined in Rissanen 1978 based on minimum-bit representation of a signal
# abic -> https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7299313/
abic = np.log(an) * k + an * np.log(mse)
abic = np.log(an) * k + np.log(n) * np.log(mse)

r2 = 1 - (ss_res / ss_tot)
r2_adj = 1 - (((1 - r2) * (n - 1)) / (n - k - 1))
Expand Down
49 changes: 49 additions & 0 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,52 @@ def get_peaks(
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)

def get_aperiodic_error(self, peak_kwargs: None | dict = None) -> np.ndarray:
"""
Computes the frequency resolved error of the aperiodic spectrum.
This method first computes the absolute of the periodic spectrum and subsequently zeroes out
any peaks in the spectrum that are potentially "oscillations", yielding the residual error of the aperiodic
spectrum as a function of frequency.
This can be useful when trying to optimize hyperparameters such as the hset.
peak_kwargs : dict
A dictionary containing keyword arguments that are passed on to the peak finding method 'get_peaks'
Returns
-------
np.ndarray
A numpy array containing the frequency resolved squared error of the aperiodic
spectrum extracted using irasa
Notes
-----
While not strictly necessary, setting peak_kwargs is highly recommended.
The reason for this is that through up-/downsampling and averaging "broadband"
parameters such as spectral knees can bleed in the periodic spectrum and could be wrongfully
interpreted as oscillations. This can be avoided by e.g. explicitely setting `min_peak_height`.
A good way of making a decision for the periodic parameters is to base it on the settings
used in peak detection.
"""

if peak_kwargs is None:
peak_kwargs = {}
# get absolute periodic spectrum
aperiodic_error = np.abs(self.periodic[0, :])

# zero-out peaks
peaks = self.get_peaks(**peak_kwargs)
freqs = self.freqs

for _, peak in peaks.iterrows():
cur_upper = peak['cf'] + peak['bw']
cur_lower = peak['cf'] - peak['bw']

freq_mask = np.logical_and(freqs < cur_upper, freqs > cur_lower)

aperiodic_error[freq_mask] = 0

return aperiodic_error

0 comments on commit f0d29a7

Please sign in to comment.