Skip to content

Commit

Permalink
Merge pull request #59 from schmidtfa/minor_fixes
Browse files Browse the repository at this point in the history
fixing information criteria and some minor issues
  • Loading branch information
schmidtfa authored Aug 23, 2024
2 parents e02f7df + 7e67975 commit a4b8a49
Show file tree
Hide file tree
Showing 11 changed files with 345 additions and 167 deletions.
248 changes: 161 additions & 87 deletions examples/basic_functionality.ipynb

Large diffs are not rendered by default.

153 changes: 111 additions & 42 deletions examples/irasa_pitfalls.ipynb

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,14 @@ def fit_aperiodic_model(
Returns
-------
AperiodicFit
An object containing two pandas DataFrames:
An object containing three pandas DataFrames:
- aperiodic_params : pd.DataFrame
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Notes
-----
Expand Down Expand Up @@ -525,11 +528,14 @@ def fit_aperiodic_model(
Returns
-------
AperiodicFit
An object containing two pandas DataFrames:
An object containing three pandas DataFrames:
- aperiodic_params : pd.DataFrame
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Notes
-----
Expand All @@ -551,7 +557,7 @@ def fit_aperiodic_model(
event_dict = {val: key for key, val in self.event_id.items()}
events = self.events[:, 2]

aps_list, gof_list = [], []
aps_list, gof_list, pred_list = [], [], []
for ix, cur_epoch in enumerate(self.get_data()):
slope_fit = compute_aperiodic_model(
cur_epoch,
Expand All @@ -566,8 +572,9 @@ def fit_aperiodic_model(
slope_fit.gof['event_id'] = event_dict[events[ix]]
aps_list.append(slope_fit.aperiodic_params.copy())
gof_list.append(slope_fit.gof.copy())
pred_list.append(slope_fit.model)

return AperiodicFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list))
return AperiodicFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list), model=pd.concat(pred_list))


@define
Expand Down
30 changes: 19 additions & 11 deletions pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _compute_aperiodic_model(
freq: np.ndarray,
fit_func: str | type[AbstractFitFun],
scale_factor: float | int = 1,
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""helper function to model the aperiodic spectrum"""

if isinstance(fit_func, str):
Expand All @@ -27,9 +27,9 @@ def _compute_aperiodic_model(
raise ValueError('fit_func should be either a string ("fixed", "knee") or of type AbastractFitFun')

fit_f = fit_func(freq, aperiodic_spectrum, scale_factor=scale_factor)
params, gof = fit_f.fit_func()
params, gof, pred = fit_f.fit_func()

return params, gof
return params, gof, pred


def compute_aperiodic_model(
Expand Down Expand Up @@ -72,11 +72,13 @@ def compute_aperiodic_model(
Returns
-------
AperiodicFit
An object containing two pandas DataFrames:
An object containing three pandas DataFrames:
- aperiodic_params : pd.DataFrame
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel.
Notes
-----
Expand Down Expand Up @@ -136,9 +138,9 @@ def num_zeros(decimal: int) -> float:
else:
scale_factor = 1

ap_list, gof_list = [], []
ap_list, gof_list, pred_list = [], [], []
for ix, ch_name in enumerate(ch_names):
params, gof = _compute_aperiodic_model(
params, gof, pred = _compute_aperiodic_model(
aperiodic_spectrum=aperiodic_spectrum[ix],
freq=freqs,
fit_func=fit_func,
Expand All @@ -147,12 +149,14 @@ def num_zeros(decimal: int) -> float:

params['ch_name'] = ch_name
gof['ch_name'] = ch_name
pred['ch_name'] = ch_name

ap_list.append(params)
gof_list.append(gof)
pred_list.append(pred)

# combine & return
return AperiodicFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list))
return AperiodicFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list), model=pd.concat(pred_list))


def compute_aperiodic_model_sprint(
Expand Down Expand Up @@ -198,12 +202,14 @@ def compute_aperiodic_model_sprint(
Returns
-------
AperiodicFit
An object containing two pandas DataFrames:
An object containing three pandas DataFrames:
- aperiodic_params : pd.DataFrame
A DataFrame containing the aperiodic parameters (e.g., center frequency, bandwidth, peak height)
A DataFrame containing the aperiodic parameters (e.g., Offset and Exponent)
for each channel and each time point.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel and each time point.
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Notes
-----
Expand All @@ -219,7 +225,7 @@ def compute_aperiodic_model_sprint(
"""

ap_t_list, gof_t_list = [], []
ap_t_list, gof_t_list, pred_t_list = [], [], []

for ix, t in enumerate(times):
aperiodic_fit = compute_aperiodic_model(
Expand All @@ -232,8 +238,10 @@ def compute_aperiodic_model_sprint(
)
aperiodic_fit.aperiodic_params['time'] = t
aperiodic_fit.gof['time'] = t
aperiodic_fit.model['time'] = t

ap_t_list.append(aperiodic_fit.aperiodic_params)
gof_t_list.append(aperiodic_fit.gof)
pred_t_list.append(aperiodic_fit.model)

return AperiodicFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list))
return AperiodicFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list), model=pd.concat(pred_t_list))
38 changes: 28 additions & 10 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,24 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd
mse = np.mean(residuals**2)
n = len(psd)

loglik = -n / 2 * (1 + np.log(mse) + np.log(2 * np.pi))
aic = 2 * (k - loglik)
bic = k * np.log(n) - 2 * loglik

# bic = n * np.log(mse) + k * np.log(n)
# aic = n * np.log(mse) + 2 * k

gof = pd.DataFrame({'mse': mse, 'r_squared': 1 - (ss_res / ss_tot), 'BIC': bic, 'AIC': aic}, index=[0])
# https://robjhyndman.com/hyndsight/lm_aic.html
# c is in practice sometimes dropped. Only relevant when comparing models with different n
# c = n + n * np.log(2 * np.pi)
# aic = 2 * k + n * np.log(mse) + c #real
aic = 2 * k + np.log(n) * np.log(mse)
# according to Sclove 1987 only difference between BIC and AIC
# is that BIC uses log(n) * k instead of 2 * k
# bic = np.log(n) * k + n * np.log(mse) + c #real
bic = np.log(n) * k + np.log(n) * np.log(mse)
# Sclove 1987 also hints at sample size adjusted bic
an = np.log((n + 2) / 24) # defined in Rissanen 1978 based on minimum-bit representation of a signal
# abic -> https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7299313/
abic = np.log(an) * k + an * np.log(mse)

r2 = 1 - (ss_res / ss_tot)
r2_adj = 1 - (((1 - r2) * (n - 1)) / (n - k - 1))

gof = pd.DataFrame({'mse': mse, 'R2': r2, 'R2_adj.': r2_adj, 'BIC': bic, 'BIC_adj.': abic, 'AIC': aic}, index=[0])
gof['fit_type'] = fit_type
return gof

Expand Down Expand Up @@ -177,7 +187,7 @@ def handle_scaling(self, df_params: pd.DataFrame, scale_factor: float) -> pd.Dat
raise ValueError('Scale Factor not handled. You need to overwrite the handle_scaling method.')
return df_params

def fit_func(self) -> tuple[pd.DataFrame, pd.DataFrame]:
def fit_func(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
curve_kwargs = self.curve_kwargs
p, _ = curve_fit(self.func, self.freq, self.aperiodic_spectrum, **curve_kwargs)

Expand All @@ -191,7 +201,15 @@ def fit_func(self) -> tuple[pd.DataFrame, pd.DataFrame]:
df_params = self.add_infos_to_df(df_params)
df_params = self.handle_scaling(df_params, scale_factor=self.scale_factor)

return df_params, df_gof
freq = self.freq.copy()
if self.log10_aperiodic:
pred = 10**pred
if self.log10_freq:
freq = 10**freq
df_pred = pd.DataFrame({'Frequency (Hz)': freq, 'aperiodic_model': pred})
df_pred['fit_type'] = self.label

return df_params, df_gof, df_pred


class FixedFitFun(AbstractFitFun):
Expand Down
17 changes: 9 additions & 8 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def _check_irasa_settings(irasa_params: dict, hset_info: tuple) -> None:
assert band_evaluated[0] > 0, 'The evaluated frequency range is 0 or lower this makes no sense'
assert band_evaluated[1] < nyquist, (
f'The evaluated frequency range goes up to {np.round(band_evaluated[1], 2)}Hz '
'which is higher than Nyquist (fs / 2)'
f'which is higher than the Nyquist frequency for your data of {nyquist}Hz, \n'
'try to either lower the upper bound for the hset or decrease the upper band limit, when running IRASA.'
)

filter_settings: list[float] = list(irasa_params['filter_settings'])
Expand All @@ -168,18 +169,18 @@ def _check_irasa_settings(irasa_params: dict, hset_info: tuple) -> None:
filter_settings[1] = band_evaluated[1]

assert np.logical_and(band_evaluated[0] >= filter_settings[0], band_evaluated[1] <= filter_settings[1]), (
f'You run IRASA in a frequency range from'
f'{np.round(band_evaluated[0], irasa_params["hset_accuracy"])} -'
f'You run IRASA in a frequency range from '
f'{np.round(band_evaluated[0], irasa_params["hset_accuracy"])} - '
f'{np.round(band_evaluated[1], irasa_params["hset_accuracy"])}Hz. \n'
'Your settings specified in "filter_settings" indicate that you have '
'a bandpass filter from '
'Your settings specified in "filter_settings" indicate that you have a pass band from '
f'{np.round(filter_settings[0], irasa_params["hset_accuracy"])} - '
f'{np.round(filter_settings[1], irasa_params["hset_accuracy"])}Hz. \n'
'This means that your evaluated range likely contains filter artifacts. \n'
'Either change your filter settings, adjust hset or the parameter "band" accordingly. \n'
f'You want to make sure that band[0] / hset.max() '
f'> {np.round(filter_settings[0], irasa_params["hset_accuracy"])} '
f'and that band[1] * hset.max() < {np.round(filter_settings[1], irasa_params["hset_accuracy"])}'
f'You want to make sure that the lower band limit divided by the upper bound of the hset '
f'> {np.round(filter_settings[0], irasa_params["hset_accuracy"])} \n'
'and that upper band limit times the upper bound of the hset < '
f'{np.round(filter_settings[1], irasa_params["hset_accuracy"])}'
)


Expand Down
1 change: 1 addition & 0 deletions pyrasa/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ class AperiodicFit:

aperiodic_params: pd.DataFrame
gof: pd.DataFrame
model: pd.DataFrame
Binary file modified simulations/example_knee.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions tests/test_compute_slope.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_slope_fitting_fixed(fixed_aperiodic_signal, fs, exponent):
aperiodic_fit_f = compute_aperiodic_model(psd, freqs, fit_func='fixed')
assert pytest.approx(aperiodic_fit_f.aperiodic_params['Exponent'][0], abs=TOLERANCE) == np.abs(exponent)
# test goodness of fit should be close to r_squared == 1 for linear model
assert aperiodic_fit_f.gof['r_squared'][0] > MIN_R2
assert aperiodic_fit_f.gof['R2'][0] > MIN_R2

# test if we can set fit bounds w/o error
# _, _ = compute_slope(psd, freqs, fit_func='fixed', fit_bounds=[2, 50])
Expand All @@ -37,7 +37,7 @@ def test_slope_fitting_fixed(fixed_aperiodic_signal, fs, exponent):
# test the effect of scaling
aperiodic_fit_fs = compute_aperiodic_model(psd, freqs, fit_func='fixed', scale=True)
assert np.isclose(aperiodic_fit_fs.aperiodic_params['Exponent'], aperiodic_fit_f.aperiodic_params['Exponent'])
assert np.isclose(aperiodic_fit_fs.gof['r_squared'], aperiodic_fit_f.gof['r_squared'])
assert np.isclose(aperiodic_fit_fs.gof['R2'], aperiodic_fit_f.gof['R2'])


@pytest.mark.parametrize('exponent, fs', [(-1, 500)], scope='session')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_irasa_knee.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee):
assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE))
# test bic/aic -> should be better for knee
assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0]
assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0]
assert slope_fit_k.gof['BIC_adj.'][0] < slope_fit_f.gof['BIC_adj.'][0]


# knee model
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE))
# test bic/aic -> should be better for knee
assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0]
assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0]
assert slope_fit_k.gof['BIC_adj.'][0] < slope_fit_f.gof['BIC_adj.'][0]
# test whether we can reconstruct the peak frequency correctly
pe_params = irasa_out.get_peaks()
assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq))
2 changes: 1 addition & 1 deletion tests/test_irasa_sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_irasa_sprint(ts4sprint, fs, exponent_1, exponent_2):
# irasa_tf.aperiodic[np.newaxis, :, :], freqs=irasa_tf.freqs, times=irasa_tf.time,
# )

assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT
assert slope_fit.gof['R2'].mean() > MIN_R2_SPRINT
assert np.isclose(
np.mean(slope_fit.aperiodic_params.query('time < 7')['Exponent']), np.abs(exponent_1), atol=SPRINT_TOLERANCE
)
Expand Down

0 comments on commit a4b8a49

Please sign in to comment.