Skip to content

Commit

Permalink
Merge branch 'name' into peakheights
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 13, 2023
2 parents 6697b30 + 9d854fd commit 44744c7
Show file tree
Hide file tree
Showing 19 changed files with 308 additions and 73 deletions.
21 changes: 21 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,27 @@ Annotated plots that describe the model and fitting process.
plot_annotated_model
plot_annotated_peak_search

Plot Utilities & Styling
~~~~~~~~~~~~~~~~~~~~~~~~

Plot related utilies for styling and managing plots.

.. currentmodule:: fooof.plts.style

.. autosummary::
:toctree: generated/

check_style_options

.. currentmodule:: fooof.plts.utils

.. autosummary::
:toctree: generated/

check_ax
recursive_plot
save_figure

Utilities
---------

Expand Down
24 changes: 5 additions & 19 deletions specparam/core/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def gaussian_function(xs, *params):

ys = np.zeros_like(xs)

for ii in range(0, len(params), 3):

ctr, hgt, wid = params[ii:ii+3]
for ctr, hgt, wid in zip(*[iter(params)] * 3):

ys = ys + hgt * np.exp(-(xs-ctr)**2 / (2*wid**2))

Expand All @@ -60,11 +58,8 @@ def expo_function(xs, *params):
Output values for exponential function.
"""

ys = np.zeros_like(xs)

offset, knee, exp = params

ys = ys + offset - np.log10(knee + xs**exp)
ys = offset - np.log10(knee + xs**exp)

return ys

Expand All @@ -88,11 +83,8 @@ def expo_nk_function(xs, *params):
Output values for exponential function, without a knee.
"""

ys = np.zeros_like(xs)

offset, exp = params

ys = ys + offset - np.log10(xs**exp)
ys = offset - np.log10(xs**exp)

return ys

Expand All @@ -113,11 +105,8 @@ def linear_function(xs, *params):
Output values for linear function.
"""

ys = np.zeros_like(xs)

offset, slope = params

ys = ys + offset + (xs*slope)
ys = offset + (xs*slope)

return ys

Expand All @@ -138,11 +127,8 @@ def quadratic_function(xs, *params):
Output values for quadratic function.
"""

ys = np.zeros_like(xs)

offset, slope, curve = params

ys = ys + offset + (xs*slope) + ((xs**2)*curve)
ys = offset + (xs*slope) + ((xs**2)*curve)

return ys

Expand Down
103 changes: 103 additions & 0 deletions specparam/core/jacobians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
""""Functions for computing Jacobian matrices to be used during fitting.
Notes
-----
These functions line up with those in `funcs`.
The parameters in these functions are labeled {a, b, c, ...}, but follow the order in `funcs`.
These functions are designed to be passed into `curve_fit` to provide a computed Jacobian.
"""

import numpy as np

###################################################################################################
###################################################################################################

## Periodic Jacobian functions

def jacobian_gauss(xs, *params):
"""Create the Jacobian matrix for the Gaussian function.
Parameters
----------
xs : 1d array
Input x-axis values.
*params : float
Parameters for the function.
Returns
-------
jacobian : 2d array
Jacobian matrix, with shape [len(xs), n_params].
"""

jacobian = np.zeros((len(xs), len(params)))

for i, (a, b, c) in enumerate(zip(*[iter(params)] * 3)):

ax = -a + xs
ax2 = ax**2

c2 = c**2
c3 = c**3

exp = np.exp(-ax2 / (2 * c2))
exp_b = exp * b

ii = i * 3
jacobian[:, ii] = (exp_b * ax) / c2
jacobian[:, ii+1] = exp
jacobian[:, ii+2] = (exp_b * ax2) / c3

return jacobian


## Aperiodic Jacobian functions

def jacobian_expo(xs, *params):
"""Create the Jacobian matrix for the exponential function.
Parameters
----------
xs : 1d array
Input x-axis values.
*params : float
Parameters for the function.
Returns
-------
jacobian : 2d array
Jacobian matrix, with shape [len(xs), n_params].
"""

a, b, c = params

xs_c = xs**c
b_xs_c = xs_c + b

jacobian = np.ones((len(xs), len(params)))
jacobian[:, 1] = -1 / b_xs_c
jacobian[:, 2] = -(xs_c * np.log10(xs)) / b_xs_c

return jacobian


def jacobian_expo_nk(xs, *params):
"""Create the Jacobian matrix for the exponential no-knee function.
Parameters
----------
xs : 1d array
Input x-axis values.
*params : float
Parameters for the function.
Returns
-------
jacobian : 2d array
Jacobian matrix, with shape [len(xs), n_params].
"""

jacobian = np.ones((len(xs), len(params)))
jacobian[:, 1] = -np.log10(xs)

return jacobian
24 changes: 18 additions & 6 deletions specparam/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from specparam.core.modutils import copy_doc_func_to_method
from specparam.core.utils import group_three, check_array_dim
from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func
from specparam.core.jacobians import jacobian_gauss
from specparam.core.errors import (FitError, NoModelError, DataError,
NoDataError, InconsistentDataError)
from specparam.core.strings import (gen_settings_str, gen_model_results_str,
Expand Down Expand Up @@ -191,12 +192,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
self._gauss_overlap_thresh = 0.75
# Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev
self._cf_bound = 1.5
# The maximum number of calls to the curve fitting function
self._maxfev = 5000
# The error metric to calculate, post model fitting. See `_calc_error` for options
# Note: this is for checking error post fitting, not an objective function for fitting
self._error_metric = 'MAE'

## PRIVATE CURVE_FIT SETTINGS
# The maximum number of calls to the curve fitting function
self._maxfev = 5000
# The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol)
# Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default
self._tol = 0.00001

## RUN MODES
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
self._debug = False
Expand Down Expand Up @@ -400,7 +406,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None,
Only relevant / effective if `freqs` and `power_spectrum` passed in in this call.
**plot_kwargs
Keyword arguments to pass into the plot method.
Plot options with a name conflict be passed by pre-pending 'plot_'.
Plot options with a name conflict be passed by pre-pending `plot_`.
e.g. `freqs`, `power_spectrum` and `freq_range`.
Notes
Expand Down Expand Up @@ -944,7 +950,9 @@ def _simple_ap_fit(self, freqs, power_spectrum):
warnings.simplefilter("ignore")
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs, power_spectrum, p0=guess,
maxfev=self._maxfev, bounds=ap_bounds)
maxfev=self._maxfev, bounds=ap_bounds,
ftol=self._tol, xtol=self._tol, gtol=self._tol,
check_finite=False)
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding parameters in "
"the simple aperiodic component fit.")
Expand Down Expand Up @@ -1001,7 +1009,9 @@ def _robust_ap_fit(self, freqs, power_spectrum):
warnings.simplefilter("ignore")
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs_ignore, spectrum_ignore, p0=popt,
maxfev=self._maxfev, bounds=ap_bounds)
maxfev=self._maxfev, bounds=ap_bounds,
ftol=self._tol, xtol=self._tol, gtol=self._tol,
check_finite=False)
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
Expand Down Expand Up @@ -1147,7 +1157,9 @@ def _fit_peak_guess(self, guess):
# Fit the peaks
try:
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds,
ftol=self._tol, xtol=self._tol, gtol=self._tol,
check_finite=False, jac=jacobian_gauss)
except RuntimeError as excp:
error_msg = ("Model fitting failed due to not finding "
"parameters in the peak component fit.")
Expand Down
4 changes: 2 additions & 2 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
Expand Down Expand Up @@ -83,7 +83,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
Expand Down
2 changes: 1 addition & 1 deletion specparam/plts/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **pl
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
Expand Down
8 changes: 4 additions & 4 deletions specparam/plts/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot_group(group, **plot_kwargs):
group : SpectralGroupModel
Object containing results from fitting a group of power spectra.
**plot_kwargs
Keyword arguments to apply to the plot.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
Raises
------
Expand Down Expand Up @@ -72,7 +72,7 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs):
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

if group.aperiodic_mode == 'knee':
Expand All @@ -97,7 +97,7 @@ def plot_group_goodness(group, ax=None, **plot_kwargs):
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

plot_scatter_2(group.get_params('error'), 'Error',
Expand All @@ -117,7 +117,7 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs):
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

plot_hist(group.get_params('peak_params', 0)[:, 0], 'Center Frequency',
Expand Down
4 changes: 2 additions & 2 deletions specparam/plts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
Keyword arguments to pass into the plot call for each plot element.
**plot_kwargs
Keyword arguments to apply to the plot.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
Notes
-----
Expand Down Expand Up @@ -163,7 +163,7 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs):
ax : matplotlib.Axes
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``fill_between``.
Keyword arguments to pass into ``fill_between``.
"""

defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}
Expand Down
4 changes: 2 additions & 2 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
Expand Down Expand Up @@ -86,7 +86,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to pass into the plot call.
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
Expand Down
12 changes: 9 additions & 3 deletions specparam/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
'linestyle' : ['ls', 'linestyle']}

# Plot style arguments are those that can be defined on an axis object
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim']
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim',
'xticks', 'yticks', 'xticklabels', 'yticklabels']

# Line style arguments are those that can be defined on a line object
LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
Expand All @@ -40,8 +41,13 @@
# Custom style arguments are those that are custom-handled by the plot style function
CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
'legend_size', 'legend_loc']
STYLERS = ['axis_styler', 'line_styler', 'custom_styler']
STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS

# Define list of available style functions - these can also be replaced by arguments
STYLERS = ['axis_styler', 'line_styler', 'collection_styler', 'custom_styler']

# Collect the full set of possible style related input keyword arguments
STYLE_ARGS = \
AXIS_STYLE_ARGS + LINE_STYLE_ARGS + COLLECTION_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS

## Define default values for plot aesthetics
# These are all custom style arguments
Expand Down
Loading

0 comments on commit 44744c7

Please sign in to comment.