Skip to content

Commit

Permalink
added aperiodic model to output of model fit
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 22, 2024
1 parent 6a61887 commit 7e67975
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 70 deletions.
159 changes: 100 additions & 59 deletions examples/basic_functionality.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def fit_aperiodic_model(
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- pred : pd.DataFrame
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Expand Down Expand Up @@ -533,7 +533,7 @@ def fit_aperiodic_model(
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- pred : pd.DataFrame
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Expand Down Expand Up @@ -572,9 +572,9 @@ def fit_aperiodic_model(
slope_fit.gof['event_id'] = event_dict[events[ix]]
aps_list.append(slope_fit.aperiodic_params.copy())
gof_list.append(slope_fit.gof.copy())
pred_list.append(slope_fit.pred)
pred_list.append(slope_fit.model)

return AperiodicFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list), pred=pd.concat(pred_list))
return AperiodicFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list), model=pd.concat(pred_list))


@define
Expand Down
12 changes: 6 additions & 6 deletions pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def compute_aperiodic_model(
A DataFrame containing the fitted aperiodic parameters for each channel.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel.
- pred : pd.DataFrame
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel.
Notes
Expand Down Expand Up @@ -156,7 +156,7 @@ def num_zeros(decimal: int) -> float:
pred_list.append(pred)

# combine & return
return AperiodicFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list), pred=pd.concat(pred_list))
return AperiodicFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list), model=pd.concat(pred_list))


def compute_aperiodic_model_sprint(
Expand Down Expand Up @@ -208,7 +208,7 @@ def compute_aperiodic_model_sprint(
for each channel and each time point.
- gof : pd.DataFrame
A DataFrame containing the goodness of fit metrics for each channel and each time point.
- pred : pd.DataFrame
- model : pd.DataFrame
A DataFrame containing the predicted aperiodic model for each channel and each time point.
Notes
Expand Down Expand Up @@ -238,10 +238,10 @@ def compute_aperiodic_model_sprint(
)
aperiodic_fit.aperiodic_params['time'] = t
aperiodic_fit.gof['time'] = t
aperiodic_fit.pred['time'] = t
aperiodic_fit.model['time'] = t

ap_t_list.append(aperiodic_fit.aperiodic_params)
gof_t_list.append(aperiodic_fit.gof)
pred_t_list.append(aperiodic_fit.pred)
pred_t_list.append(aperiodic_fit.model)

return AperiodicFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list), pred=pd.concat(pred_t_list))
return AperiodicFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list), model=pd.concat(pred_t_list))
1 change: 1 addition & 0 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def fit_func(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
if self.log10_freq:
freq = 10**freq
df_pred = pd.DataFrame({'Frequency (Hz)': freq, 'aperiodic_model': pred})
df_pred['fit_type'] = self.label

return df_params, df_gof, df_pred

Expand Down
2 changes: 1 addition & 1 deletion pyrasa/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ class AperiodicFit:

aperiodic_params: pd.DataFrame
gof: pd.DataFrame
pred: pd.DataFrame
model: pd.DataFrame
Binary file modified simulations/example_knee.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 7e67975

Please sign in to comment.