Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Group shading #199

Merged
merged 9 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ Plots for plotting power spectra with shaded regions.

plot_spectrum_shading
plot_spectra_shading
plot_spectra_yshade

Plot Model Properties & Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
79 changes: 79 additions & 0 deletions fooof/plts/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
This file contains functions for plotting power spectra, that take in data directly.
"""

from inspect import isfunction
from itertools import repeat, cycle

import numpy as np
from scipy.stats import sem

from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
Expand Down Expand Up @@ -172,3 +174,80 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',

style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False),
plot_kwargs.get('log_powers', False))


@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1,
log_freqs=False, log_powers=False, color=None, label=None,
ax=None, **plot_kwargs):
"""Plot standard deviation or error as a shaded region around the mean spectrum.

Parameters
----------
freqs : 1d array
Frequency values, to be plotted on the x-axis.
power_spectra : 1d or 2d array
Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d.
shade : 'std', 'sem', 1d array or callable, optional, default: 'std'
Approach for shading above/below the mean spectrum.
average : 'mean', 'median' or callable, optional, default: 'mean'
Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d.
scale : int, optional, default: 1
Factor to multiply the plotted shade by.
log_freqs : bool, optional, default: False
Whether to plot the frequency axis in log spacing.
log_powers : bool, optional, default: False
Whether to plot the power axis in log spacing.
color : str, optional, default: None
Line color of the spectrum.
label : str, optional, default: None
Legend label for the spectrum.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Keyword arguments to be passed to `plot_spectra` or to the plot call.
"""

if (isinstance(shade, str) or isfunction(shade)) and power_spectra.ndim != 2:
raise ValueError('Power spectra must be 2d if shade is not given.')

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))

# Set plot data & labels, logging if requested
plt_freqs = np.log10(freqs) if log_freqs else freqs
plt_powers = np.log10(power_spectra) if log_powers else power_spectra

# Organize mean spectrum to plot
avg_funcs = {'mean' : np.mean, 'median' : np.median}

if isinstance(average, str) and plt_powers.ndim == 2:
avg_powers = avg_funcs[average](plt_powers, axis=0)
elif isfunction(average) and plt_powers.ndim == 2:
avg_powers = average(plt_powers)
else:
avg_powers = plt_powers

# Plot average power spectrum
ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label)

# Organize shading to plot
shade_funcs = {'std' : np.std, 'sem' : sem}

if isinstance(shade, str):
shade_vals = scale * shade_funcs[shade](plt_powers, axis=0)
elif isfunction(shade):
shade_vals = scale * shade(plt_powers)
else:
shade_vals = scale * shade

upper_shade = avg_powers + shade_vals
lower_shade = avg_powers - shade_vals

# Plot +/- yshading around spectrum
alpha = plot_kwargs.pop('alpha', 0.25)
ax.fill_between(plt_freqs, lower_shade, upper_shade,
alpha=alpha, color=color, **plot_kwargs)

style_spectrum_plot(ax, log_freqs, log_powers)
36 changes: 36 additions & 0 deletions fooof/tests/plts/test_spectra.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for fooof.plts.spectra."""

from pytest import raises

import numpy as np

from fooof.tests.tutils import plot_test
Expand Down Expand Up @@ -59,3 +61,37 @@ def test_plot_spectra_shading(tfg, skip_if_no_mpl):
shades=[8, 12], add_center=True, log_freqs=True, log_powers=True,
labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectra_shading_kwargs.png')

@plot_test
def test_plot_spectra_yshade(skip_if_no_mpl, tfg):

freqs = tfg.freqs
powers = tfg.power_spectra

# Invalid 1d array, without shade
with raises(ValueError):
plot_spectra_yshade(freqs, powers[0])

# Plot with 2d array
plot_spectra_yshade(freqs, powers, shade='std',
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectra_yshade1.png')

# Plot shade with given 1d array
plot_spectra_yshade(freqs, np.mean(powers, axis=0),
shade=np.std(powers, axis=0),
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectra_yshade2.png')

# Plot shade with different average and shade approaches
plot_spectra_yshade(freqs, powers, shade='sem', average='median',
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectra_yshade3.png')

# Plot shade with custom average and shade callables
def _average_callable(powers): return np.mean(powers, axis=0)
def _shade_callable(powers): return np.std(powers, axis=0)

plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable,
log_powers=True, save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectra_yshade4.png')