Skip to content

Commit

Permalink
adjusted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 22, 2024
1 parent e14a7e2 commit c4a0224
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
13 changes: 8 additions & 5 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,14 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd
n = len(psd)

# https://robjhyndman.com/hyndsight/lm_aic.html
c = n + n * np.log(2 * np.pi)
aic = 2 * k + n * np.log(mse) + c
# c is in practice sometimes dropped. Only relevant when comparing models with different n
# c = n + n * np.log(2 * np.pi)
# aic = 2 * k + n * np.log(mse) + c #real
aic = 2 * k + np.log(n) * np.log(mse)
# according to Sclove 1987 only difference between BIC and AIC
# is that BIC uses log(n) * k instead of 2
bic = np.log(n) * k + n * np.log(mse) + c
# is that BIC uses log(n) * k instead of 2 * k
# bic = np.log(n) * k + n * np.log(mse) + c #real
bic = np.log(n) * k + np.log(n) * np.log(mse)
# Sclove 1987 also hints at sample size adjusted bic
an = np.log((n + 2) / 24) # defined in Rissanen 1978 based on minimum-bit representation of a signal
# abic -> https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7299313/
Expand All @@ -89,7 +92,7 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd
r2 = 1 - (ss_res / ss_tot)
r2_adj = 1 - (((1 - r2) * (n - 1)) / (n - k - 1))

gof = pd.DataFrame({'mse': mse, 'R2': r2, 'R2_adj.': r2_adj, 'BIC': bic, 'ABIC': abic, 'AIC': aic}, index=[0])
gof = pd.DataFrame({'mse': mse, 'R2': r2, 'R2_adj.': r2_adj, 'BIC': bic, 'BIC_adj.': abic, 'AIC': aic}, index=[0])
gof['fit_type'] = fit_type
return gof

Expand Down
2 changes: 1 addition & 1 deletion tests/test_compute_slope.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_slope_fitting_fixed(fixed_aperiodic_signal, fs, exponent):
# test the effect of scaling
aperiodic_fit_fs = compute_aperiodic_model(psd, freqs, fit_func='fixed', scale=True)
assert np.isclose(aperiodic_fit_fs.aperiodic_params['Exponent'], aperiodic_fit_f.aperiodic_params['Exponent'])
assert np.isclose(aperiodic_fit_fs.gof['r_squared'], aperiodic_fit_f.gof['r_squared'])
assert np.isclose(aperiodic_fit_fs.gof['R2'], aperiodic_fit_f.gof['R2'])


@pytest.mark.parametrize('exponent, fs', [(-1, 500)], scope='session')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_irasa_knee.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee):
assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE))
# test bic/aic -> should be better for knee
assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0]
assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0]
assert slope_fit_k.gof['BIC_adj.'][0] < slope_fit_f.gof['BIC_adj.'][0]


# knee model
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE))
# test bic/aic -> should be better for knee
assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0]
assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0]
assert slope_fit_k.gof['BIC_adj.'][0] < slope_fit_f.gof['BIC_adj.'][0]
# test whether we can reconstruct the peak frequency correctly
pe_params = irasa_out.get_peaks()
assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq))
2 changes: 1 addition & 1 deletion tests/test_irasa_sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_irasa_sprint(ts4sprint, fs, exponent_1, exponent_2):
# irasa_tf.aperiodic[np.newaxis, :, :], freqs=irasa_tf.freqs, times=irasa_tf.time,
# )

assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT
assert slope_fit.gof['R2'].mean() > MIN_R2_SPRINT
assert np.isclose(
np.mean(slope_fit.aperiodic_params.query('time < 7')['Exponent']), np.abs(exponent_1), atol=SPRINT_TOLERANCE
)
Expand Down

0 comments on commit c4a0224

Please sign in to comment.