Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add r-squared metric #252

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fooof/core/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_description():
"""

attributes = {'results' : ['aperiodic_params_', 'gaussian_params_', 'peak_params_',
'r_squared_', 'error_'],
'r_squared_', 'adj_r_squared_', 'error_'],
'settings' : ['peak_width_limits', 'max_n_peaks',
'min_peak_height', 'peak_threshold',
'aperiodic_mode'],
Expand Down
4 changes: 4 additions & 0 deletions fooof/core/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def gen_results_fm_str(fm, concise=False):
# Goodness if fit
'Goodness of fit metrics:',
'R^2 of model fit is {:5.4f}'.format(fm.r_squared_),
'Adjusted R^2 of model fit is {:5.4f}'.format(fm.adj_r_squared_),
'Error of the fit is {:5.4f}'.format(fm.error_),
'',

Expand Down Expand Up @@ -351,6 +352,7 @@ def gen_results_fg_str(fg, concise=False):
# Extract all the relevant data for printing
n_peaks = len(fg.get_params('peak_params'))
r2s = fg.get_params('r_squared')
adj_r2s = fg.get_params('adj_r_squared')
errors = fg.get_params('error')
exps = fg.get_params('aperiodic_params', 'exponent')
kns = fg.get_params('aperiodic_params', 'knee') \
Expand Down Expand Up @@ -399,6 +401,8 @@ def gen_results_fg_str(fg, concise=False):
'Goodness of fit metrics:',
' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
.format(np.nanmin(r2s), np.nanmax(r2s), np.nanmean(r2s)),
'Adj R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
.format(np.nanmin(adj_r2s), np.nanmax(adj_r2s), np.nanmean(adj_r2s)),
'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
.format(np.nanmin(errors), np.nanmax(errors), np.nanmean(errors)),
'',
Expand Down
6 changes: 5 additions & 1 deletion fooof/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class FOOOFMetaData(namedtuple('FOOOFMetaData', ['freq_range', 'freq_res'])):


class FOOOFResults(namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params',
'r_squared', 'error', 'gaussian_params'])):
'r_squared', 'adj_r_squared', 'error',
'gaussian_params'])):
"""Model results from parameterizing a power spectrum.

Parameters
Expand All @@ -68,6 +69,9 @@ class FOOOFResults(namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params'
Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW].
r_squared : float
R-squared of the fit between the full model fit and the input data.
adj_r_squared : float
R-squared of the fit between the full model fit and the input data,
adjusted for the number of parameters in the model.
error : float
Error of the full model fit.
gaussian_params : 2d array
Expand Down
13 changes: 12 additions & 1 deletion fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class FOOOF():
Each row is a gaussian, as [mean, height, standard deviation].
r_squared_ : float
R-squared of the fit between the input power spectrum and the full model fit.
adj_r_squared_ : float
Adjusted R-squared of the fit between the input power spectrum and the full model fit,
adjusted for the number of parameters in the model.
error_ : float
Error of the full model fit.
n_peaks_ : int
Expand Down Expand Up @@ -281,6 +284,7 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res
self.gaussian_params_ = np.empty([0, 3])
self.peak_params_ = np.empty([0, 3])
self.r_squared_ = np.nan
self.adj_r_squared_ = np.nan
self.error_ = np.nan

self.fooofed_spectrum_ = None
Expand Down Expand Up @@ -367,6 +371,7 @@ def add_results(self, fooof_result):
self.gaussian_params_ = fooof_result.gaussian_params
self.peak_params_ = fooof_result.peak_params
self.r_squared_ = fooof_result.r_squared
self.adj_r_squared_ = fooof_result.adj_r_squared
self.error_ = fooof_result.error

self._check_loaded_results(fooof_result._asdict())
Expand Down Expand Up @@ -567,7 +572,7 @@ def get_params(self, name, col=None):

Parameters
----------
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'}
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared', 'adj_r_squared'}
Name of the data field to extract.
col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional
Column name / index to extract from selected data, if requested.
Expand Down Expand Up @@ -1098,8 +1103,14 @@ def _drop_peak_overlap(self, guess):
def _calc_r_squared(self):
"""Calculate the r-squared goodness of fit of the model, compared to the original data."""

# compute r-squared
r_val = np.corrcoef(self.power_spectrum, self.fooofed_spectrum_)
self.r_squared_ = r_val[0][1] ** 2

# compute adjusted r-squared
n = len(self.power_spectrum) # number of data points
k = len(self.peak_params_) * 3 + 2 # number of parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mwprestonjr - just checking in here: the number of parameters here is computed as n_peaks * 3 (reflecting the 3 Gaussian params per Gaussian), and then + 2, presumably for aperiodic - but I think this should be +2 if fixed, +3 if knee mode, right? Or, more generally, n_params = n_peaks * n_params_per_peak + n_ap_param?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Tom, nice catch. Sorry I missed this. I've now made this update.

self.adj_r_squared_ = 1 - (1 - self.r_squared_) * (n - 1) / (n - k - 1)


def _calc_error(self, metric=None):
Expand Down
6 changes: 3 additions & 3 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class FOOOFGroup(FOOOF):
and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth).
- The FOOOFGroup object inherits from the FOOOF object. As such it also has data
attributes (`power_spectrum` & `fooofed_spectrum_`), and parameter attributes
(`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`)
which are defined in the context of individual model fits. These attributes are
(`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `adj_r_squared_`,
`error_`) which are defined in the context of individual model fits. These attributes are
used during the fitting process, but in the group context do not store results
post-fitting. Rather, all model fit results are collected and stored into the
`group_results` attribute. To access individual parameters of the fit, use
Expand Down Expand Up @@ -334,7 +334,7 @@ def get_params(self, name, col=None):

Parameters
----------
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'}
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared', 'adj_r_squared'}
Name of the data field to extract across the group.
col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional
Column name / index to extract from selected data, if requested.
Expand Down
3 changes: 2 additions & 1 deletion fooof/objs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):

# Goodness of fit measures: extract & average
r2 = avg_func(fg.get_params('r_squared'))
adj_r2 = avg_func(fg.get_params('adj_r_squared'))
error = avg_func(fg.get_params('error'))

# Collect all results together, to be added to FOOOF object
results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)
results = FOOOFResults(ap_params, peak_params, r2, adj_r2, error, gauss_params)

# Create the new FOOOF object, with settings, data info & results
fm = FOOOF()
Expand Down
14 changes: 12 additions & 2 deletions fooof/plts/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fooof.plts.templates import plot_scatter_1, plot_scatter_2, plot_hist
from fooof.plts.utils import savefig
from fooof.plts.style import style_plot
from fooof.plts.utils import check_ax

plt = safe_import('.pyplot', 'matplotlib')
gridspec = safe_import('.gridspec', 'matplotlib')
Expand Down Expand Up @@ -104,8 +105,17 @@ def plot_fg_gf(fg, ax=None, **plot_kwargs):
Keyword arguments to pass into the ``style_plot``.
"""

plot_scatter_2(fg.get_params('error'), 'Error',
fg.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax)
ax = check_ax(ax)
ax1 = ax.twinx()

plot_scatter_1(fg.get_params('error'), 'Error', 'Goodness of Fit', color='#1f77b4', ax=ax)
plot_scatter_1(fg.get_params('r_squared'), 'R^2', x_val=1, color='#1f77b4', ax=ax1)
plot_scatter_1(fg.get_params('adj_r_squared'), x_val=2, color='#1f77b4', ax=ax1)

ax.set(xlim=[-0.5, 2.5],
xticks=[0, 1, 2],
xticklabels=['Error', 'R^2', 'Adj. R^2'])
ax.tick_params(axis='x', labelsize=16)


@savefig
Expand Down
11 changes: 9 additions & 2 deletions fooof/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
###################################################################################################

@check_dependency(plt, 'matplotlib')
def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None):
def plot_scatter_1(data, label=None, title=None, x_val=0, color=None, ax=None):
"""Plot a scatter plot, with a single y-axis.

Parameters
Expand All @@ -30,6 +30,9 @@ def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None):
Title for the plot.
x_val : int, optional, default: 0
Position along the x-axis to plot set of data.
color : color, optional, default: None
Color of data points plotted ('c' argument for pyplot.scatter).
None will use default color.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.

Expand All @@ -43,7 +46,11 @@ def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None):
# Create x-axis data, with small jitter for visualization purposes
x_data = np.ones_like(data) * x_val + np.random.normal(0, 0.025, data.shape)

ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data)))
# Plot the data
if color is None:
ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data)))
else:
ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data)), c=color)

if label:
ax.set_ylabel(label, fontsize=16)
Expand Down
2 changes: 1 addition & 1 deletion fooof/tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_fooof_meta_data():

def test_fooof_results():

results = FOOOFResults([1, 1], [10, 0.5, 1], 0.95, 0.05, [10, 0.5, 0.5])
results = FOOOFResults([1, 1], [10, 0.5, 1], 0.95, 0.94, 0.05, [10, 0.5, 0.5])
assert results

results_fields = OBJ_DESC['results']
Expand Down
4 changes: 2 additions & 2 deletions fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_add_data():

# Test that prior data does not get cleared, when requesting not to clear
tfm._reset_data_results(True, True, True)
tfm.add_results(FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]))
tfm.add_results(FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.94, 0.02, [10, 0.5, 0.25]))
tfm.add_data(freqs, pows, clear_results=False)
assert tfm.has_data
assert tfm.has_model
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_add_results():
tfm = get_tfm()

# Test adding results
fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])
fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.94, 0.02, [10, 0.5, 0.25])
tfm.add_results(fooof_results)
assert tfm.has_model
for setting in OBJ_DESC['results']:
Expand Down