From bf66b098859230920e0fe446dcfca3afa18928c1 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Fri, 5 Mar 2021 17:47:53 -0800 Subject: [PATCH 1/8] group shading --- fooof/plts/error.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index c870900b..c7e93b54 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -53,3 +53,48 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, check_n_style(plot_style, ax, log_freqs, True) ax.set_ylabel('Absolute Error') + + +@check_dependency(plt, 'matplotlib') +def plot_shade_spectra(freqs, power_spectra, shade=None, log_freqs=False, log_powers=False, + ax=None, plot_style=style_spectrum_plot, **plot_kwargs): + """Plot error or standard deviation as a shaded region. + + Parameters + ---------- + freqs : 1d array + Frequency values, to be plotted on the x-axis. + power_spectra : 2d array + Power values, to be plotted on the y-axis. + shade : 1d array, optional, default: None + Values to shade in around the plotted error. None defaults to standard deviation. + log_freqs : bool, optional, default: False + Whether to plot the frequency axis in log spacing. + log_powers : bool, optional, default: False + Whether to plot the power axis in log spacing. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + plot_style : callable, optional, default: style_spectrum_plot + A function to call to apply styling & aesthetics to the plot. + **plot_kwargs + Keyword arguments to be passed to `plot_spectra` or to the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + + # Set plot data & labels, logging if requested + plt_freqs = np.log10(freqs) if log_freqs else freqs + plt_powers = np.log10(power_spectra) if log_powers else power_spectra + + # Shade +/- 1 standard deviation + powers_mean = np.mean(plt_powers, axis=0) + + shade = np.std(plt_powers, axis=0) if shade is None else shade + upper_shade = powers_mean + shade + lower_shade = powers_mean - shade + + # Fill shade + alpha = plot_kwargs.pop('alpha', 0.25) + ax.fill_between(freqs, lower_shade, upper_shade, alpha=alpha) + + check_n_style(plot_style, ax, log_freqs, log_powers) From bd38137ea3185fd75677445693c0ff44f938f2b4 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 8 Mar 2021 11:25:29 -0800 Subject: [PATCH 2/8] improved error shading --- fooof/plts/error.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index c7e93b54..c6bac65b 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -56,18 +56,20 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, @check_dependency(plt, 'matplotlib') -def plot_shade_spectra(freqs, power_spectra, shade=None, log_freqs=False, log_powers=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): - """Plot error or standard deviation as a shaded region. +def plot_error_shade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, + log_powers=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): + """Plot standard deviation or error as a shaded region around the mean spectrum. Parameters ---------- freqs : 1d array Frequency values, to be plotted on the x-axis. - power_spectra : 2d array - Power values, to be plotted on the y-axis. + power_spectra : 1d or 2d array + Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. shade : 1d array, optional, default: None - Values to shade in around the plotted error. None defaults to standard deviation. + Powers to shade above/below the mean spectrum. None defaults to one standard deviation. + scale : int, optional, default: 1 + Factor to multiply the the standard deviation, or ``shade``, by. log_freqs : bool, optional, default: False Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False @@ -80,21 +82,25 @@ def plot_shade_spectra(freqs, power_spectra, shade=None, log_freqs=False, log_po Keyword arguments to be passed to `plot_spectra` or to the plot call. """ + if shade is None and power_spectra.ndim != 2: + raise ValueError('Power spectra must be 2d if shade is not given.') + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) # Set plot data & labels, logging if requested plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Shade +/- 1 standard deviation - powers_mean = np.mean(plt_powers, axis=0) + # Plot mean + powers_mean = np.mean(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers + ax.plot(plt_freqs, powers_mean) - shade = np.std(plt_powers, axis=0) if shade is None else shade + # Shade +/- scale * (standard deviation or shade) + shade = scale * np.std(plt_powers, axis=0) if shade is None else scale * shade upper_shade = powers_mean + shade lower_shade = powers_mean - shade - # Fill shade alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(freqs, lower_shade, upper_shade, alpha=alpha) + ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs) check_n_style(plot_style, ax, log_freqs, log_powers) From 201869e2e00681283ce9db983dd67cdec6e86e0a Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 8 Mar 2021 11:25:41 -0800 Subject: [PATCH 3/8] error shading tests --- fooof/tests/plts/test_error.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/fooof/tests/plts/test_error.py b/fooof/tests/plts/test_error.py index 2bffbc7a..3cb26c56 100644 --- a/fooof/tests/plts/test_error.py +++ b/fooof/tests/plts/test_error.py @@ -1,5 +1,7 @@ """Tests for fooof.plts.error.""" +from pytest import raises, mark, param + import numpy as np from fooof.tests.tutils import plot_test @@ -16,3 +18,20 @@ def test_plot_spectral_error(skip_if_no_mpl): errs = np.ones(len(fs)) plot_spectral_error(fs, errs) + + +@plot_test +def test_plot_error_shade(skip_if_no_mpl, tfg): + + freqs = tfg.freqs + powers = tfg.power_spectra + + # Invalid 1d array, without shade + with raises(ValueError): + plot_error_shade(freqs, powers[0]) + + # Valid 1d array with shade + plot_error_shade(freqs, np.mean(powers, axis=0), shade=np.std(powers, axis=0)) + + # 2d array + plot_error_shade(freqs, powers) From cc23f95474b1c42b56603ee971e87dc11a5bc53d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 12 Apr 2021 00:00:01 -0400 Subject: [PATCH 4/8] update plt approach --- fooof/plts/error.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index 66df69bf..066c672e 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -54,9 +54,11 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **pl ax.set_ylabel('Absolute Error') +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_error_shade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, - log_powers=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): + log_powers=False, ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. Parameters @@ -102,4 +104,4 @@ def plot_error_shade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, alpha = plot_kwargs.pop('alpha', 0.25) ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs) - check_n_style(plot_style, ax, log_freqs, log_powers) + style_spectrum_plot(ax, log_freqs, log_powers) From 8b4c93e1a0a6c3794d544d1d5d5336408bdf9db5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 12 Apr 2021 00:13:19 -0400 Subject: [PATCH 5/8] move yshade func --- fooof/plts/error.py | 53 -------------------------------- fooof/plts/spectra.py | 53 ++++++++++++++++++++++++++++++++ fooof/tests/plts/test_error.py | 23 -------------- fooof/tests/plts/test_spectra.py | 21 +++++++++++++ 4 files changed, 74 insertions(+), 76 deletions(-) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index 066c672e..f7cbfdf7 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -52,56 +52,3 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **pl style_spectrum_plot(ax, log_freqs, True) ax.set_ylabel('Absolute Error') - - -@savefig -@style_plot -@check_dependency(plt, 'matplotlib') -def plot_error_shade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, - log_powers=False, ax=None, **plot_kwargs): - """Plot standard deviation or error as a shaded region around the mean spectrum. - - Parameters - ---------- - freqs : 1d array - Frequency values, to be plotted on the x-axis. - power_spectra : 1d or 2d array - Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 1d array, optional, default: None - Powers to shade above/below the mean spectrum. None defaults to one standard deviation. - scale : int, optional, default: 1 - Factor to multiply the the standard deviation, or ``shade``, by. - log_freqs : bool, optional, default: False - Whether to plot the frequency axis in log spacing. - log_powers : bool, optional, default: False - Whether to plot the power axis in log spacing. - ax : matplotlib.Axes, optional - Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. - **plot_kwargs - Keyword arguments to be passed to `plot_spectra` or to the plot call. - """ - - if shade is None and power_spectra.ndim != 2: - raise ValueError('Power spectra must be 2d if shade is not given.') - - ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - - # Set plot data & labels, logging if requested - plt_freqs = np.log10(freqs) if log_freqs else freqs - plt_powers = np.log10(power_spectra) if log_powers else power_spectra - - # Plot mean - powers_mean = np.mean(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers - ax.plot(plt_freqs, powers_mean) - - # Shade +/- scale * (standard deviation or shade) - shade = scale * np.std(plt_powers, axis=0) if shade is None else scale * shade - upper_shade = powers_mean + shade - lower_shade = powers_mean - shade - - alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs) - - style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index 9c9f4ab5..73cebe88 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -172,3 +172,56 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), plot_kwargs.get('log_powers', False)) + + +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, + log_powers=False, ax=None, **plot_kwargs): + """Plot standard deviation or error as a shaded region around the mean spectrum. + + Parameters + ---------- + freqs : 1d array + Frequency values, to be plotted on the x-axis. + power_spectra : 1d or 2d array + Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. + shade : 1d array, optional, default: None + Powers to shade above/below the mean spectrum. None defaults to one standard deviation. + scale : int, optional, default: 1 + Factor to multiply the the standard deviation, or ``shade``, by. + log_freqs : bool, optional, default: False + Whether to plot the frequency axis in log spacing. + log_powers : bool, optional, default: False + Whether to plot the power axis in log spacing. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + plot_style : callable, optional, default: style_spectrum_plot + A function to call to apply styling & aesthetics to the plot. + **plot_kwargs + Keyword arguments to be passed to `plot_spectra` or to the plot call. + """ + + if shade is None and power_spectra.ndim != 2: + raise ValueError('Power spectra must be 2d if shade is not given.') + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + + # Set plot data & labels, logging if requested + plt_freqs = np.log10(freqs) if log_freqs else freqs + plt_powers = np.log10(power_spectra) if log_powers else power_spectra + + # Plot mean + powers_mean = np.mean(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers + ax.plot(plt_freqs, powers_mean) + + # Shade +/- scale * (standard deviation or shade) + shade = scale * np.std(plt_powers, axis=0) if shade is None else scale * shade + upper_shade = powers_mean + shade + lower_shade = powers_mean - shade + + alpha = plot_kwargs.pop('alpha', 0.25) + ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs) + + style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/fooof/tests/plts/test_error.py b/fooof/tests/plts/test_error.py index eefea9c8..3e8b817b 100644 --- a/fooof/tests/plts/test_error.py +++ b/fooof/tests/plts/test_error.py @@ -1,7 +1,5 @@ """Tests for fooof.plts.error.""" -from pytest import raises, mark, param - import numpy as np from fooof.tests.tutils import plot_test @@ -20,24 +18,3 @@ def test_plot_spectral_error(skip_if_no_mpl): plot_spectral_error(fs, errs, save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectral_error.png') - - - -@plot_test -def test_plot_error_shade(skip_if_no_mpl, tfg): - - freqs = tfg.freqs - powers = tfg.power_spectra - - # Invalid 1d array, without shade - with raises(ValueError): - plot_error_shade(freqs, powers[0]) - - # Valid 1d array with shade - plot_error_shade(freqs, np.mean(powers, axis=0), shade=np.std(powers, axis=0), - save_fig=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectral_error_shade1.png') - - # 2d array - plot_error_shade(freqs, powers, save_fig=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectral_error_shade2.png') diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 0b85e2d9..fe7e5177 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -1,5 +1,7 @@ """Tests for fooof.plts.spectra.""" +from pytest import raises + import numpy as np from fooof.tests.tutils import plot_test @@ -59,3 +61,22 @@ def test_plot_spectra_shading(tfg, skip_if_no_mpl): shades=[8, 12], add_center=True, log_freqs=True, log_powers=True, labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_shading_kwargs.png') + +@plot_test +def test_plot_spectra_yshade(skip_if_no_mpl, tfg): + + freqs = tfg.freqs + powers = tfg.power_spectra + + # Invalid 1d array, without shade + with raises(ValueError): + plot_spectra_yshade(freqs, powers[0]) + + # Valid 1d array with shade + plot_spectra_yshade(freqs, np.mean(powers, axis=0), shade=np.std(powers, axis=0), + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade1.png') + + # 2d array + plot_spectra_yshade(freqs, powers, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade2.png') From 2ec52b1f9ec90fa2fc1bda670f7f6971ff4d85a4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 12 Apr 2021 00:23:50 -0400 Subject: [PATCH 6/8] add yshade func to API list --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index d3aedc10..fe851486 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -259,6 +259,7 @@ Plots for plotting power spectra with shaded regions. plot_spectrum_shading plot_spectra_shading + plot_spectra_yshade Plot Model Properties & Parameters ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From d78495e9ac2dae581ccb9200f2f6992e9d75e193 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 12 Apr 2021 00:55:16 -0400 Subject: [PATCH 7/8] update args for average & shade --- fooof/plts/spectra.py | 47 +++++++++++++++++++++----------- fooof/tests/plts/test_spectra.py | 15 +++++++--- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index 73cebe88..2022315e 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -8,6 +8,7 @@ from itertools import repeat, cycle import numpy as np +from scipy.stats import sem from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES @@ -177,8 +178,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=False, - log_powers=False, ax=None, **plot_kwargs): +def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, + log_freqs=False, log_powers=False, color=None, label=None, + ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. Parameters @@ -187,23 +189,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal Frequency values, to be plotted on the x-axis. power_spectra : 1d or 2d array Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 1d array, optional, default: None - Powers to shade above/below the mean spectrum. None defaults to one standard deviation. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. scale : int, optional, default: 1 - Factor to multiply the the standard deviation, or ``shade``, by. + Factor to multiply the plotted shade by. log_freqs : bool, optional, default: False Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False Whether to plot the power axis in log spacing. + color : str, optional, default: None + Line color of the spectrum. + label : str, optional, default: None + Legend label for the spectrum. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs Keyword arguments to be passed to `plot_spectra` or to the plot call. """ - if shade is None and power_spectra.ndim != 2: + if isinstance(shade, str) and power_spectra.ndim != 2: raise ValueError('Power spectra must be 2d if shade is not given.') ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) @@ -212,16 +218,25 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Plot mean - powers_mean = np.mean(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers - ax.plot(plt_freqs, powers_mean) + # Organize mean spectrum to plot + avg_funcs = {'mean' : np.mean, 'median' : np.median} + avg_func = avg_funcs[average] if isinstance(average, str) else average + avg_powers = avg_func(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers + + # Plot average power spectrum + ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) - # Shade +/- scale * (standard deviation or shade) - shade = scale * np.std(plt_powers, axis=0) if shade is None else scale * shade - upper_shade = powers_mean + shade - lower_shade = powers_mean - shade + # Organize shading to plot + shade_funcs = {'std' : np.std, 'sem' : sem} + shade_func = shade_funcs[shade] if isinstance(shade, str) else shade + shade_vals = scale * shade_func(plt_powers, axis=0) \ + if isinstance(shade, str) else scale * shade + upper_shade = avg_powers + shade_vals + lower_shade = avg_powers - shade_vals + # Plot +/- yshading around spectrum alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs) + ax.fill_between(plt_freqs, lower_shade, upper_shade, + alpha=alpha, color=color, **plot_kwargs) style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index fe7e5177..343e7961 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -72,11 +72,18 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): with raises(ValueError): plot_spectra_yshade(freqs, powers[0]) - # Valid 1d array with shade - plot_spectra_yshade(freqs, np.mean(powers, axis=0), shade=np.std(powers, axis=0), + # Plot with 2d array + plot_spectra_yshade(freqs, powers, shade='std', save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade1.png') - # 2d array - plot_spectra_yshade(freqs, powers, save_fig=True, file_path=TEST_PLOTS_PATH, + # Plot shade with given 1d array + plot_spectra_yshade(freqs, np.mean(powers, axis=0), + shade=np.std(powers, axis=0), + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade2.png') + + # Plot shade with different average and shade approaches + plot_spectra_yshade(freqs, powers, shade='sem', average='median', + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade3.png') From 849d9db833fd94cbd58bf80ef0e8d2888f08e819 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 12 Apr 2021 14:10:33 -0700 Subject: [PATCH 8/8] custom callables fix --- fooof/plts/spectra.py | 23 +++++++++++++++++------ fooof/tests/plts/test_spectra.py | 8 ++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index 2022315e..af90733b 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -5,6 +5,7 @@ This file contains functions for plotting power spectra, that take in data directly. """ +from inspect import isfunction from itertools import repeat, cycle import numpy as np @@ -209,7 +210,7 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale Keyword arguments to be passed to `plot_spectra` or to the plot call. """ - if isinstance(shade, str) and power_spectra.ndim != 2: + if (isinstance(shade, str) or isfunction(shade)) and power_spectra.ndim != 2: raise ValueError('Power spectra must be 2d if shade is not given.') ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) @@ -220,17 +221,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale # Organize mean spectrum to plot avg_funcs = {'mean' : np.mean, 'median' : np.median} - avg_func = avg_funcs[average] if isinstance(average, str) else average - avg_powers = avg_func(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers + + if isinstance(average, str) and plt_powers.ndim == 2: + avg_powers = avg_funcs[average](plt_powers, axis=0) + elif isfunction(average) and plt_powers.ndim == 2: + avg_powers = average(plt_powers) + else: + avg_powers = plt_powers # Plot average power spectrum ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) # Organize shading to plot shade_funcs = {'std' : np.std, 'sem' : sem} - shade_func = shade_funcs[shade] if isinstance(shade, str) else shade - shade_vals = scale * shade_func(plt_powers, axis=0) \ - if isinstance(shade, str) else scale * shade + + if isinstance(shade, str): + shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) + elif isfunction(shade): + shade_vals = scale * shade(plt_powers) + else: + shade_vals = scale * shade + upper_shade = avg_powers + shade_vals lower_shade = avg_powers - shade_vals diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 343e7961..75301ae2 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -87,3 +87,11 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): plot_spectra_yshade(freqs, powers, shade='sem', average='median', save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade3.png') + + # Plot shade with custom average and shade callables + def _average_callable(powers): return np.mean(powers, axis=0) + def _shade_callable(powers): return np.std(powers, axis=0) + + plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, + log_powers=True, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_yshade4.png')