diff --git a/doc/api.rst b/doc/api.rst index d3aedc10..fe851486 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -259,6 +259,7 @@ Plots for plotting power spectra with shaded regions. plot_spectrum_shading plot_spectra_shading + plot_spectra_yshade Plot Model Properties & Parameters ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index 9c9f4ab5..af90733b 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -5,9 +5,11 @@ This file contains functions for plotting power spectra, that take in data directly. """ +from inspect import isfunction from itertools import repeat, cycle import numpy as np +from scipy.stats import sem from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES @@ -172,3 +174,80 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), plot_kwargs.get('log_powers', False)) + + +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, + log_freqs=False, log_powers=False, color=None, label=None, + ax=None, **plot_kwargs): + """Plot standard deviation or error as a shaded region around the mean spectrum. + + Parameters + ---------- + freqs : 1d array + Frequency values, to be plotted on the x-axis. + power_spectra : 1d or 2d array + Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. + scale : int, optional, default: 1 + Factor to multiply the plotted shade by. + log_freqs : bool, optional, default: False + Whether to plot the frequency axis in log spacing. + log_powers : bool, optional, default: False + Whether to plot the power axis in log spacing. + color : str, optional, default: None + Line color of the spectrum. + label : str, optional, default: None + Legend label for the spectrum. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to be passed to `plot_spectra` or to the plot call. + """ + + if (isinstance(shade, str) or isfunction(shade)) and power_spectra.ndim != 2: + raise ValueError('Power spectra must be 2d if shade is not given.') + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + + # Set plot data & labels, logging if requested + plt_freqs = np.log10(freqs) if log_freqs else freqs + plt_powers = np.log10(power_spectra) if log_powers else power_spectra + + # Organize mean spectrum to plot + avg_funcs = {'mean' : np.mean, 'median' : np.median} + + if isinstance(average, str) and plt_powers.ndim == 2: + avg_powers = avg_funcs[average](plt_powers, axis=0) + elif isfunction(average) and plt_powers.ndim == 2: + avg_powers = average(plt_powers) + else: + avg_powers = plt_powers + + # Plot average power spectrum + ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) + + # Organize shading to plot + shade_funcs = {'std' : np.std, 'sem' : sem} + + if isinstance(shade, str): + shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) + elif isfunction(shade): + shade_vals = scale * shade(plt_powers) + else: + shade_vals = scale * shade + + upper_shade = avg_powers + shade_vals + lower_shade = avg_powers - shade_vals + + # Plot +/- yshading around spectrum + alpha = plot_kwargs.pop('alpha', 0.25) + ax.fill_between(plt_freqs, lower_shade, upper_shade, + alpha=alpha, color=color, **plot_kwargs) + + style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 0b85e2d9..75301ae2 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -1,5 +1,7 @@ """Tests for fooof.plts.spectra.""" +from pytest import raises + import numpy as np from fooof.tests.tutils import plot_test @@ -59,3 +61,37 @@ def test_plot_spectra_shading(tfg, skip_if_no_mpl): shades=[8, 12], add_center=True, log_freqs=True, log_powers=True, labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_shading_kwargs.png') + +@plot_test +def test_plot_spectra_yshade(skip_if_no_mpl, tfg): + + freqs = tfg.freqs + powers = tfg.power_spectra + + # Invalid 1d array, without shade + with raises(ValueError): + plot_spectra_yshade(freqs, powers[0]) + + # Plot with 2d array + plot_spectra_yshade(freqs, powers, shade='std', + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade1.png') + + # Plot shade with given 1d array + plot_spectra_yshade(freqs, np.mean(powers, axis=0), + shade=np.std(powers, axis=0), + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade2.png') + + # Plot shade with different average and shade approaches + plot_spectra_yshade(freqs, powers, shade='sem', average='median', + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade3.png') + + # Plot shade with custom average and shade callables + def _average_callable(powers): return np.mean(powers, axis=0) + def _shade_callable(powers): return np.std(powers, axis=0) + + plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, + log_powers=True, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade4.png')