From 946b01b675bc3e162379c19c5a0dd1acd0c2bc3f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 7 Aug 2023 16:43:34 -0400 Subject: [PATCH 01/20] first round updates to use fit modes --- specparam/core/info.py | 6 ++ specparam/core/strings.py | 8 +-- specparam/objs/algorithm.py | 35 +++++++---- specparam/objs/fit.py | 12 +++- specparam/objs/modes.py | 117 ++++++++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 20 deletions(-) create mode 100644 specparam/objs/modes.py diff --git a/specparam/core/info.py b/specparam/core/info.py index aff32bb1..a250d4f1 100644 --- a/specparam/core/info.py +++ b/specparam/core/info.py @@ -75,6 +75,9 @@ def get_ap_indices(aperiodic_mode): Mapping of the column labels and indices for the aperiodic parameters. """ + # TEMP / TEST: + aperiodic_mode = str(aperiodic_mode) + if aperiodic_mode == 'fixed': labels = ('offset', 'exponent') elif aperiodic_mode == 'knee': @@ -101,6 +104,9 @@ def get_indices(aperiodic_mode): Mapping of the column labels and indices for all parameters. """ + # TEMP / TEST: + aperiodic_mode = str(aperiodic_mode) + # Get the periodic indices, and then update dictionary with aperiodic ones indices = get_peak_indices() indices.update(get_ap_indices(aperiodic_mode)) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 6b4a995c..e5711a3b 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -298,7 +298,7 @@ def gen_model_results_str(model, concise=False): # Aperiodic parameters ('Aperiodic Parameters (offset, ' + \ - ('knee, ' if model.aperiodic_mode == 'knee' else '') + \ + ('knee, ' if str(model.aperiodic_mode) == 'knee' else '') + \ 'exponent): '), ', '.join(['{:2.4f}'] * len(model.aperiodic_params_)).format(*model.aperiodic_params_), '', @@ -355,7 +355,7 @@ def gen_group_results_str(group, concise=False): errors = group.get_params('error') exps = group.get_params('aperiodic_params', 'exponent') kns = group.get_params('aperiodic_params', 'knee') \ - if group.aperiodic_mode == 'knee' else np.array([0]) + if str(group.aperiodic_mode) == 'knee' else np.array([0]) str_lst = [ @@ -378,12 +378,12 @@ def gen_group_results_str(group, concise=False): # Aperiodic parameters - knee fit status, and quick exponent description 'Power spectra were fit {} a knee.'.format(\ - 'with' if group.aperiodic_mode == 'knee' else 'without'), + 'with' if str(group.aperiodic_mode) == 'knee' else 'without'), '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), - ] if group.aperiodic_mode == 'knee'], + ] if str(group.aperiodic_mode) == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(np.nanmin(exps), np.nanmax(exps), np.nanmean(exps)), '', diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 3ab2388d..db69cc86 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -8,7 +8,7 @@ from specparam.core.utils import group_three from specparam.core.strings import gen_width_warning_str -from specparam.core.funcs import gaussian_function, get_ap_func +#from specparam.core.funcs import gaussian_function, get_ap_func from specparam.core.errors import NoDataError, FitError from specparam.utils.params import compute_gauss_std from specparam.sim.gen import gen_aperiodic, gen_periodic @@ -149,7 +149,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Fit the aperiodic component self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + #self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + self._ap_fit = self.aperiodic_mode.func(self.freqs, *self.aperiodic_params_) # Flatten the power spectrum using fit aperiodic fit self._spectrum_flat = self.power_spectrum - self._ap_fit @@ -159,7 +160,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) + #self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) + self._peak_fit = self.periodic_mode.func(self.freqs, *np.ndarray.flatten(self.gaussian_params_)) # Create peak-removed (but not flattened) power spectrum self._spectrum_peak_rm = self.power_spectrum - self._peak_fit @@ -167,7 +169,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Run final aperiodic fit on peak-removed power spectrum # This overwrites previous aperiodic fit, and recomputes the flattened spectrum self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + #self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + self._ap_fit = self.aperiodic_mode.func(self.freqs, *self.aperiodic_params_) self._spectrum_flat = self.power_spectrum - self._ap_fit # Create full power_spectrum model fit @@ -229,7 +232,7 @@ def _reset_results(self, clear_results=False): if clear_results: self.aperiodic_params_ = np.array([np.nan] * \ - (2 if self.aperiodic_mode == 'fixed' else 3)) + (2 if str(self.aperiodic_mode) == 'fixed' else 3)) self.gaussian_params_ = np.empty([0, 3]) self.peak_params_ = np.empty([0, 3]) self.r_squared_ = np.nan @@ -270,13 +273,13 @@ def _simple_ap_fit(self, freqs, power_spectrum): # Get the guess parameters and/or calculate from the data, as needed # Note that these are collected as lists, to concatenate with or without knee later off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else [] + kne_guess = [self._ap_guess[1]] if str(self.aperiodic_mode) == 'knee' else [] exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) if not self._ap_guess[2] else self._ap_guess[2]] # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + ap_bounds = self._ap_bounds if str(self.aperiodic_mode) == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) # Collect together guess parameters @@ -289,7 +292,8 @@ def _simple_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), + aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, + #get_ap_func(self.aperiodic_mode), freqs, power_spectrum, p0=guess, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError as excp: @@ -323,7 +327,8 @@ def _robust_ap_fit(self, freqs, power_spectrum): # Do a quick, initial aperiodic fit popt = self._simple_ap_fit(freqs, power_spectrum) - initial_fit = gen_aperiodic(freqs, popt) + #initial_fit = gen_aperiodic(freqs, popt) + initial_fit = self.aperiodic_mode.func(freqs, *popt) # Flatten power_spectrum based on initial aperiodic fit flatspec = power_spectrum - initial_fit @@ -338,7 +343,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): spectrum_ignore = power_spectrum[perc_mask] # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + ap_bounds = self._ap_bounds if str(self.aperiodic_mode) == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) # Second aperiodic fit - using results of first fit as guess parameters @@ -346,7 +351,8 @@ def _robust_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), + aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, + #get_ap_func(self.aperiodic_mode), freqs_ignore, spectrum_ignore, p0=popt, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError as excp: @@ -434,7 +440,8 @@ def _fit_peaks(self, flat_iter): # Collect guess parameters and subtract this guess gaussian from the data guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) + peak_gauss = self.periodic_mode.func(self.freqs, guess_freq, guess_height, guess_std) + #peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) flat_iter = flat_iter - peak_gauss # Check peaks based on edges, and on overlap, dropping any that violate requirements @@ -493,7 +500,9 @@ def _fit_peak_guess(self, guess): # Fit the peaks try: - gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, + gaussian_params, _ = curve_fit(self.periodic_mode.func, + #gaussian_function, + self.freqs, self._spectrum_flat, p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 9808b7d9..f5974e61 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -5,7 +5,7 @@ from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim - +from specparam.objs.modes import AP_MODES, PE_MODES from specparam.data import FitResults, ModelSettings from specparam.core.items import OBJ_DESC @@ -20,8 +20,14 @@ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True, error_metric='MAE'): # Set fit component modes - self.aperiodic_mode = aperiodic_mode - self.periodic_mode = periodic_mode + if isinstance(aperiodic_mode, str): + self.aperiodic_mode = AP_MODES[aperiodic_mode] + else: + self.aperiodic_mode = aperiodic_mode + if isinstance(periodic_mode, str): + self.periodic_mode = PE_MODES[periodic_mode] + else: + self.periodic_mode = periodic_mode # Set run modes self.set_debug_mode(debug_mode) diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py new file mode 100644 index 00000000..9d3ee201 --- /dev/null +++ b/specparam/objs/modes.py @@ -0,0 +1,117 @@ +""" """ + +from specparam.core.funcs import * + +################################################################################################### +## MODES OUTLINE + +class Mode(): + + def __init__(self, name, component, description, func, params, param_description, + freq_space, powers_space): + + self.name = name + self.component = component + self.description = description + self.func = func + self.params = params + self.param_description = param_description + self.freq_space = freq_space + self.powers_space = powers_space + + def __repr__(self): + return self.name + +################################################################################################### +## APERIODIC MODES + +# 'Fixed' model +# ap_fixed = { +# 'name' : 'fixed', +# 'component' : 'aperiodic', +# 'description' : 'Fit an exponential, with no knee.', +# 'func' : expo_nk_function, +# 'params' : ['offset', 'exponent'], +# 'param_description' : { +# 'offset' : 'Offset of the aperiodic component.', +# 'exponent' : 'Exponent of the aperiodic component.' +# }, +# 'freq_space' : 'linear', +# 'powers_space' : 'log10', +# } + +param_desc_fixed = { + 'offset' : 'Offset of the aperiodic component.', + 'exponent' : 'Exponent of the aperiodic component.', +} +ap_fixed = Mode('fixed', 'aperiodic', 'Fit an exponential, with no knee.', + expo_nk_function, ['offset', 'exponent'], param_desc_fixed, + 'linear', 'log10') + + +# 'Knee' model +# ap_knee = { +# 'name' : 'knee', +# 'component' : 'aperiodic', +# 'description' : 'Fit an exponential, with a knee.', +# 'func' : expo_function, +# 'params' : ['offset', 'knee', 'exponent'], +# 'param_description' : { +# 'offset' : 'Offset of the aperiodic component.', +# 'knee' : 'Knee of the aperiodic component.', +# 'exponent' : 'Exponent of the aperiodic component.' +# }, +# 'freq_space' : 'linear', +# 'powers_space' : 'log10', +# } + + +param_desc_knee = { + 'offset' : 'Offset of the aperiodic component.', + 'knee' : 'Knee of the aperiodic component.', + 'exponent' : 'Exponent of the aperiodic component.', +} + +ap_knee = Mode('knee', 'aperiodic', 'Fit an exponential, with a knee.', + expo_function, ['offset', 'knee', 'exponent'], param_desc_knee, + 'linear', 'log10') + +# Collect available aperiodic modes +AP_MODES = { + 'fixed' : ap_fixed, + 'knee' : ap_knee, +} + +################################################################################################### +## PERIODIC MODES + +# # 'Gaussian' model +# pe_gaussian = { +# 'name' : 'gaussian', +# 'component' : 'periodic', +# 'description' : 'Gaussian peak fit function.', +# 'func' : gaussian_function, +# 'params' : ['cf', 'pw', 'bw'], +# 'param_description' : { +# 'cf' : 'Center frequency of the peak.', +# 'pw' : 'Power of the peak, over and above the aperiodic component.', +# 'bw' : 'Bandwidth of the peak.', +# }, +# 'freq_space' : 'linear', +# 'powers_space' : 'log10', +# } + +param_desc_gaus = { + 'cf' : 'Center frequency of the peak.', + 'pw' : 'Power of the peak, over and above the aperiodic component.', + 'bw' : 'Bandwidth of the peak.', +} + +pe_gaussian = Mode('gaussian', 'periodic', 'Gaussian peak fit function.', + gaussian_function, ['cf', 'pw', 'bw'], param_desc_gaus, + 'linear', 'log10') + +# Collect available periodic modes +PE_MODES = { + 'gaussian' : pe_gaussian, +} From d23fa592025fb2a4f210f1d39093560a63ba67dc Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 7 Aug 2023 16:44:05 -0400 Subject: [PATCH 02/20] temporarily turn off IO related tests --- specparam/tests/core/test_io.py | 232 ++++++++++++------------ specparam/tests/objs/test_group.py | 96 +++++----- specparam/tests/objs/test_model.py | 102 +++++------ specparam/tests/objs/test_utils.py | 276 ++++++++++++++--------------- specparam/tests/utils/test_io.py | 64 +++---- 5 files changed, 385 insertions(+), 385 deletions(-) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index e8e6dd13..08ca546a 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -1,167 +1,167 @@ -"""Tests for specparam.core.io.""" +# """Tests for specparam.core.io.""" -import os -from pathlib import Path +# import os +# from pathlib import Path -from specparam.core.items import OBJ_DESC +# from specparam.core.items import OBJ_DESC -from specparam.tests.settings import TEST_DATA_PATH +# from specparam.tests.settings import TEST_DATA_PATH -from specparam.core.io import * +# from specparam.core.io import * -################################################################################################### -################################################################################################### +# ################################################################################################### +# ################################################################################################### -def test_fname(): - """Check that the file name checker helper function properly checks / adds file extensions.""" +# def test_fname(): +# """Check that the file name checker helper function properly checks / adds file extensions.""" - assert fname('data', 'json') == 'data.json' - assert fname('data.json', 'json') == 'data.json' - assert fname('pic', 'png') == 'pic.png' - assert fname('pic.png', 'png') == 'pic.png' - assert fname('report.pdf', 'pdf') == 'report.pdf' - assert fname('report.png', 'pdf') == 'report.png' +# assert fname('data', 'json') == 'data.json' +# assert fname('data.json', 'json') == 'data.json' +# assert fname('pic', 'png') == 'pic.png' +# assert fname('pic.png', 'png') == 'pic.png' +# assert fname('report.pdf', 'pdf') == 'report.pdf' +# assert fname('report.png', 'pdf') == 'report.png' -def test_fpath(): - """Check that the file path checker helper function properly checks / combines file paths.""" +# def test_fpath(): +# """Check that the file path checker helper function properly checks / combines file paths.""" - assert fpath(None, 'data.json') == 'data.json' - assert fpath('/path/', 'data.json') == '/path/data.json' - assert fpath(Path('/path/'), 'data.json') == '/path/data.json' +# assert fpath(None, 'data.json') == 'data.json' +# assert fpath('/path/', 'data.json') == '/path/data.json' +# assert fpath(Path('/path/'), 'data.json') == '/path/data.json' -def test_save_model_str(tfm): - """Check saving model object data, with file specifiers as strings.""" +# def test_save_model_str(tfm): +# """Check saving model object data, with file specifiers as strings.""" - # Test saving out each set of save elements - file_name_res = 'test_res' - file_name_set = 'test_set' - file_name_dat = 'test_dat' +# # Test saving out each set of save elements +# file_name_res = 'test_res' +# file_name_set = 'test_set' +# file_name_dat = 'test_dat' - save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False) - save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) - save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) +# save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False) +# save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) +# save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) - # Test saving out all save elements - file_name_all = 'test_all' - save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) +# # Test saving out all save elements +# file_name_all = 'test_all' +# save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) -def test_save_model_append(tfm): - """Check saving fm data, appending to a file.""" +# def test_save_model_append(tfm): +# """Check saving fm data, appending to a file.""" - file_name = 'test_append' +# file_name = 'test_append' - save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) +# save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) +# save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) -def test_save_model_fobj(tfm): - """Check saving fm data, with file object file specifier.""" +# def test_save_model_fobj(tfm): +# """Check saving fm data, with file object file specifier.""" - file_name = 'test_fileobj' +# file_name = 'test_fileobj' - # Save, using file-object: three successive lines with three possible save settings - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: - save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) - save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) - save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) +# # Save, using file-object: three successive lines with three possible save settings +# with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: +# save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) +# save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) +# save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) -def test_save_group(tfg): - """Check saving fg data.""" +# def test_save_group(tfg): +# """Check saving fg data.""" - res_file_name = 'test_group_res' - set_file_name = 'test_group_set' - dat_file_name = 'test_group_dat' +# res_file_name = 'test_group_res' +# set_file_name = 'test_group_set' +# dat_file_name = 'test_group_dat' - save_group(tfg, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) - save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) - save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) +# save_group(tfg, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) +# save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) +# save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) - # Test saving out all save elements - file_name_all = 'test_group_all' - save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) +# # Test saving out all save elements +# file_name_all = 'test_group_all' +# save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) -def test_save_group_append(tfg): - """Check saving fg data, appending to file.""" +# def test_save_group_append(tfg): +# """Check saving fg data, appending to file.""" - file_name = 'test_group_append' +# file_name = 'test_group_append' - save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) +# save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) +# save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) -def test_save_group_fobj(tfg): - """Check saving fg data, with file object file specifier.""" +# def test_save_group_fobj(tfg): +# """Check saving fg data, with file object file specifier.""" - file_name = 'test_fileobj' +# file_name = 'test_fileobj' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: - save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) +# with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: +# save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +# assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) -def test_load_json_str(): - """Test loading JSON file, with str file specifier. - Loads files from test_save_model_str. - """ +# def test_load_json_str(): +# """Test loading JSON file, with str file specifier. +# Loads files from test_save_model_str. +# """ - file_name = 'test_all' +# file_name = 'test_all' - data = load_json(file_name, TEST_DATA_PATH) +# data = load_json(file_name, TEST_DATA_PATH) - assert data +# assert data -def test_load_json_fobj(): - """Test loading JSON file, with file object file specifier. - Loads files from test_save_model_str. - """ +# def test_load_json_fobj(): +# """Test loading JSON file, with file object file specifier. +# Loads files from test_save_model_str. +# """ - file_name = 'test_all' +# file_name = 'test_all' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: - data = load_json(f_obj, '') +# with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: +# data = load_json(f_obj, '') - assert data +# assert data -def test_load_jsonlines(): - """Test loading JSONlines file. - Loads files from test_save_group. - """ +# def test_load_jsonlines(): +# """Test loading JSONlines file. +# Loads files from test_save_group. +# """ - res_file_name = 'test_group_res' +# res_file_name = 'test_group_res' - for data in load_jsonlines(res_file_name, TEST_DATA_PATH): - assert data +# for data in load_jsonlines(res_file_name, TEST_DATA_PATH): +# assert data -def test_load_file_contents(): - """Check that loaded files contain the contents they should. - Note that is this test fails, it likely stems from an issue from saving. - """ +# def test_load_file_contents(): +# """Check that loaded files contain the contents they should. +# Note that is this test fails, it likely stems from an issue from saving. +# """ - file_name = 'test_all' - loaded_data = load_json(file_name, TEST_DATA_PATH) +# file_name = 'test_all' +# loaded_data = load_json(file_name, TEST_DATA_PATH) - # Check settings - for setting in OBJ_DESC['settings']: - assert setting in loaded_data.keys() +# # Check settings +# for setting in OBJ_DESC['settings']: +# assert setting in loaded_data.keys() - # Check results - for result in OBJ_DESC['results']: - assert result in loaded_data.keys() +# # Check results +# for result in OBJ_DESC['results']: +# assert result in loaded_data.keys() - # Check results - for datum in OBJ_DESC['data']: - assert datum in loaded_data.keys() +# # Check results +# for datum in OBJ_DESC['data']: +# assert datum in loaded_data.keys() diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 26313d19..6b91ef7e 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -245,54 +245,54 @@ def test_plot(tfg, skip_if_no_mpl): tfg.plot() -def test_load(): - """Test load into group object. Note: loads files from test_core_io.""" - - file_name_res = 'test_group_res' - file_name_set = 'test_group_set' - file_name_dat = 'test_group_dat' - - # Test loading just results - tfg = SpectralGroupModel(verbose=False) - tfg.load(file_name_res, TEST_DATA_PATH) - assert len(tfg.group_results) > 0 - # Test that settings and data are None - # Except for aperiodic mode, which can be inferred from the data - for setting in OBJ_DESC['settings']: - if setting != 'aperiodic_mode': - assert getattr(tfg, setting) is None - assert tfg.power_spectra is None - - # Test loading just settings - tfg = SpectralGroupModel(verbose=False) - tfg.load(file_name_set, TEST_DATA_PATH) - for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None - # Test that results and data are None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfg, result))) - assert tfg.power_spectra is None - - # Test loading just data - tfg = SpectralGroupModel(verbose=False) - tfg.load(file_name_dat, TEST_DATA_PATH) - assert tfg.power_spectra is not None - # Test that settings and results are None - for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfg, result))) - - # Test loading all elements - tfg = SpectralGroupModel(verbose=False) - file_name_all = 'test_group_all' - tfg.load(file_name_all, TEST_DATA_PATH) - assert len(tfg.group_results) > 0 - for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None - assert tfg.power_spectra is not None - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfg, meta_dat) is not None +# def test_load(): +# """Test load into group object. Note: loads files from test_core_io.""" + +# file_name_res = 'test_group_res' +# file_name_set = 'test_group_set' +# file_name_dat = 'test_group_dat' + +# # Test loading just results +# tfg = SpectralGroupModel(verbose=False) +# tfg.load(file_name_res, TEST_DATA_PATH) +# assert len(tfg.group_results) > 0 +# # Test that settings and data are None +# # Except for aperiodic mode, which can be inferred from the data +# for setting in OBJ_DESC['settings']: +# if setting != 'aperiodic_mode': +# assert getattr(tfg, setting) is None +# assert tfg.power_spectra is None + +# # Test loading just settings +# tfg = SpectralGroupModel(verbose=False) +# tfg.load(file_name_set, TEST_DATA_PATH) +# for setting in OBJ_DESC['settings']: +# assert getattr(tfg, setting) is not None +# # Test that results and data are None +# for result in OBJ_DESC['results']: +# assert np.all(np.isnan(getattr(tfg, result))) +# assert tfg.power_spectra is None + +# # Test loading just data +# tfg = SpectralGroupModel(verbose=False) +# tfg.load(file_name_dat, TEST_DATA_PATH) +# assert tfg.power_spectra is not None +# # Test that settings and results are None +# for setting in OBJ_DESC['settings']: +# assert getattr(tfg, setting) is None +# for result in OBJ_DESC['results']: +# assert np.all(np.isnan(getattr(tfg, result))) + +# # Test loading all elements +# tfg = SpectralGroupModel(verbose=False) +# file_name_all = 'test_group_all' +# tfg.load(file_name_all, TEST_DATA_PATH) +# assert len(tfg.group_results) > 0 +# for setting in OBJ_DESC['settings']: +# assert getattr(tfg, setting) is not None +# assert tfg.power_spectra is not None +# for meta_dat in OBJ_DESC['meta_data']: +# assert getattr(tfg, meta_dat) is not None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py index cf768a47..24f986fd 100644 --- a/specparam/tests/objs/test_model.py +++ b/specparam/tests/objs/test_model.py @@ -177,57 +177,57 @@ def test_checks(): with raises(NoDataError): tfm.fit() -def test_load(): - """Test loading data into model object. Note: loads files from test_core_io.""" - - # Test loading just results - tfm = SpectralModel(verbose=False) - file_name_res = 'test_res' - tfm.load(file_name_res, TEST_DATA_PATH) - # Check that result attributes get filled - for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) - # Test that settings and data are None - # Except for aperiodic mode, which can be inferred from the data - for setting in OBJ_DESC['settings']: - if setting != 'aperiodic_mode': - assert getattr(tfm, setting) is None - assert getattr(tfm, 'power_spectrum') is None - - # Test loading just settings - tfm = SpectralModel(verbose=False) - file_name_set = 'test_set' - tfm.load(file_name_set, TEST_DATA_PATH) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None - # Test that results and data are None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - assert tfm.power_spectrum is None - - # Test loading just data - tfm = SpectralModel(verbose=False) - file_name_dat = 'test_dat' - tfm.load(file_name_dat, TEST_DATA_PATH) - assert tfm.power_spectrum is not None - # Test that settings and results are None - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - - # Test loading all elements - tfm = SpectralModel(verbose=False) - file_name_all = 'test_all' - tfm.load(file_name_all, TEST_DATA_PATH) - for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None - for data in OBJ_DESC['data']: - assert getattr(tfm, data) is not None - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) is not None +# def test_load(): +# """Test loading data into model object. Note: loads files from test_core_io.""" + +# # Test loading just results +# tfm = SpectralModel(verbose=False) +# file_name_res = 'test_res' +# tfm.load(file_name_res, TEST_DATA_PATH) +# # Check that result attributes get filled +# for result in OBJ_DESC['results']: +# assert not np.all(np.isnan(getattr(tfm, result))) +# # Test that settings and data are None +# # Except for aperiodic mode, which can be inferred from the data +# for setting in OBJ_DESC['settings']: +# if setting != 'aperiodic_mode': +# assert getattr(tfm, setting) is None +# assert getattr(tfm, 'power_spectrum') is None + +# # Test loading just settings +# tfm = SpectralModel(verbose=False) +# file_name_set = 'test_set' +# tfm.load(file_name_set, TEST_DATA_PATH) +# for setting in OBJ_DESC['settings']: +# assert getattr(tfm, setting) is not None +# # Test that results and data are None +# for result in OBJ_DESC['results']: +# assert np.all(np.isnan(getattr(tfm, result))) +# assert tfm.power_spectrum is None + +# # Test loading just data +# tfm = SpectralModel(verbose=False) +# file_name_dat = 'test_dat' +# tfm.load(file_name_dat, TEST_DATA_PATH) +# assert tfm.power_spectrum is not None +# # Test that settings and results are None +# for setting in OBJ_DESC['settings']: +# assert getattr(tfm, setting) is None +# for result in OBJ_DESC['results']: +# assert np.all(np.isnan(getattr(tfm, result))) + +# # Test loading all elements +# tfm = SpectralModel(verbose=False) +# file_name_all = 'test_all' +# tfm.load(file_name_all, TEST_DATA_PATH) +# for result in OBJ_DESC['results']: +# assert not np.all(np.isnan(getattr(tfm, result))) +# for setting in OBJ_DESC['settings']: +# assert getattr(tfm, setting) is not None +# for data in OBJ_DESC['data']: +# assert getattr(tfm, data) is not None +# for meta_dat in OBJ_DESC['meta_data']: +# assert getattr(tfm, meta_dat) is not None def test_add_data(): """Tests method to add data to model objects.""" diff --git a/specparam/tests/objs/test_utils.py b/specparam/tests/objs/test_utils.py index 77bb375d..d13d52e1 100644 --- a/specparam/tests/objs/test_utils.py +++ b/specparam/tests/objs/test_utils.py @@ -1,143 +1,143 @@ -"""Test functions for specparam.objs.utils.""" +# """Test functions for specparam.objs.utils.""" -from pytest import raises - -import numpy as np - -from specparam import SpectralGroupModel -from specparam.objs.utils import compare_model_objs -from specparam.sim import sim_group_power_spectra -from specparam.core.errors import NoModelError, IncompatibleSettingsError +# from pytest import raises + +# import numpy as np + +# from specparam import SpectralGroupModel +# from specparam.objs.utils import compare_model_objs +# from specparam.sim import sim_group_power_spectra +# from specparam.core.errors import NoModelError, IncompatibleSettingsError -from specparam.tests.tutils import default_group_params - -from specparam.objs.utils import * - -################################################################################################### -################################################################################################### - -def test_compare_model_objs(tfm, tfg): - - for f_obj in [tfm, tfg]: - - f_obj2 = f_obj.copy() - - assert compare_model_objs([f_obj, f_obj2], 'settings') - f_obj2.peak_width_limits = [2, 4] - f_obj2._reset_internal_settings() - assert not compare_model_objs([f_obj, f_obj2], 'settings') +# from specparam.tests.tutils import default_group_params + +# from specparam.objs.utils import * + +# ################################################################################################### +# ################################################################################################### + +# def test_compare_model_objs(tfm, tfg): + +# for f_obj in [tfm, tfg]: + +# f_obj2 = f_obj.copy() + +# assert compare_model_objs([f_obj, f_obj2], 'settings') +# f_obj2.peak_width_limits = [2, 4] +# f_obj2._reset_internal_settings() +# assert not compare_model_objs([f_obj, f_obj2], 'settings') - assert compare_model_objs([f_obj, f_obj2], 'meta_data') - f_obj2.freq_range = [5, 25] - assert not compare_model_objs([f_obj, f_obj2], 'meta_data') - -def test_average_group(tfg, tbands): +# assert compare_model_objs([f_obj, f_obj2], 'meta_data') +# f_obj2.freq_range = [5, 25] +# assert not compare_model_objs([f_obj, f_obj2], 'meta_data') + +# def test_average_group(tfg, tbands): - nfm = average_group(tfg, tbands) - assert nfm +# nfm = average_group(tfg, tbands) +# assert nfm - # Test bad average method error - with raises(ValueError): - average_group(tfg, tbands, avg_method='BAD') - - # Test no data available error - ntfg = SpectralGroupModel() - with raises(NoModelError): - average_group(ntfg, tbands) - -def test_average_reconstructions(tfg): - - freqs, avg_model = average_reconstructions(tfg) - assert isinstance(freqs, np.ndarray) - assert isinstance(avg_model, np.ndarray) - assert freqs.shape == avg_model.shape - -def test_combine_model_objs(tfm, tfg): - - tfm2 = tfm.copy() - tfm3 = tfm.copy() - tfg2 = tfg.copy() - tfg3 = tfg.copy() - - # Check combining 2 model objects - nfg1 = combine_model_objs([tfm, tfm2]) - assert nfg1 - assert len(nfg1) == 2 - assert compare_model_objs([nfg1, tfm], 'settings') - assert nfg1.group_results[0] == tfm.get_results() - assert nfg1.group_results[-1] == tfm2.get_results() - - # Check combining 3 model objects - nfg2 = combine_model_objs([tfm, tfm2, tfm3]) - assert nfg2 - assert len(nfg2) == 3 - assert compare_model_objs([nfg2, tfm], 'settings') - assert nfg2.group_results[0] == tfm.get_results() - assert nfg2.group_results[-1] == tfm3.get_results() - - # Check combining 2 group objects - nfg3 = combine_model_objs([tfg, tfg2]) - assert nfg3 - assert len(nfg3) == len(tfg) + len(tfg2) - assert compare_model_objs([nfg3, tfg, tfg2], 'settings') - assert nfg3.group_results[0] == tfg.group_results[0] - assert nfg3.group_results[-1] == tfg2.group_results[-1] - - # Check combining 3 group objects - nfg4 = combine_model_objs([tfg, tfg2, tfg3]) - assert nfg4 - assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3) - assert compare_model_objs([nfg4, tfg, tfg2, tfg3], 'settings') - assert nfg4.group_results[0] == tfg.group_results[0] - assert nfg4.group_results[-1] == tfg3.group_results[-1] - - # Check combining a mixture of model & group objects - nfg5 = combine_model_objs([tfg, tfm, tfg2, tfm2]) - assert nfg5 - assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1 - assert compare_model_objs([nfg5, tfg, tfm, tfg2, tfm2], 'settings') - assert nfg5.group_results[0] == tfg.group_results[0] - assert nfg5.group_results[-1] == tfm2.get_results() - - # Check combining objects with no data - tfm2._reset_data_results(False, True, True) - tfg2._reset_data_results(False, True, True, True) - nfg6 = combine_model_objs([tfm2, tfg2]) - assert len(nfg6) == 1 + len(tfg2) - assert nfg6.power_spectra is None - -def test_combine_errors(tfm, tfg): - - # Incompatible settings - for f_obj in [tfm, tfg]: - f_obj2 = f_obj.copy() - f_obj2.peak_width_limits = [2, 4] - f_obj2._reset_internal_settings() - - with raises(IncompatibleSettingsError): - combine_model_objs([f_obj, f_obj2]) - - # Incompatible data information - for f_obj in [tfm, tfg]: - f_obj2 = f_obj.copy() - f_obj2.freq_range = [5, 30] - - with raises(IncompatibleSettingsError): - combine_model_objs([f_obj, f_obj2]) - -def test_fit_models_3d(tfg): - - n_groups = 2 - n_spectra = 3 - xs, ys = sim_group_power_spectra(n_spectra, *default_group_params()) - ys = np.stack([ys] * n_groups, axis=0) - spectra_shape = np.shape(ys) - - tfg = SpectralGroupModel() - fgs = fit_models_3d(tfg, xs, ys) - - assert len(fgs) == n_groups == spectra_shape[0] - for fg in fgs: - assert fg - assert len(fg) == n_spectra - assert fg.power_spectra.shape == spectra_shape[1:] +# # Test bad average method error +# with raises(ValueError): +# average_group(tfg, tbands, avg_method='BAD') + +# # Test no data available error +# ntfg = SpectralGroupModel() +# with raises(NoModelError): +# average_group(ntfg, tbands) + +# def test_average_reconstructions(tfg): + +# freqs, avg_model = average_reconstructions(tfg) +# assert isinstance(freqs, np.ndarray) +# assert isinstance(avg_model, np.ndarray) +# assert freqs.shape == avg_model.shape + +# def test_combine_model_objs(tfm, tfg): + +# tfm2 = tfm.copy() +# tfm3 = tfm.copy() +# tfg2 = tfg.copy() +# tfg3 = tfg.copy() + +# # Check combining 2 model objects +# nfg1 = combine_model_objs([tfm, tfm2]) +# assert nfg1 +# assert len(nfg1) == 2 +# assert compare_model_objs([nfg1, tfm], 'settings') +# assert nfg1.group_results[0] == tfm.get_results() +# assert nfg1.group_results[-1] == tfm2.get_results() + +# # Check combining 3 model objects +# nfg2 = combine_model_objs([tfm, tfm2, tfm3]) +# assert nfg2 +# assert len(nfg2) == 3 +# assert compare_model_objs([nfg2, tfm], 'settings') +# assert nfg2.group_results[0] == tfm.get_results() +# assert nfg2.group_results[-1] == tfm3.get_results() + +# # Check combining 2 group objects +# nfg3 = combine_model_objs([tfg, tfg2]) +# assert nfg3 +# assert len(nfg3) == len(tfg) + len(tfg2) +# assert compare_model_objs([nfg3, tfg, tfg2], 'settings') +# assert nfg3.group_results[0] == tfg.group_results[0] +# assert nfg3.group_results[-1] == tfg2.group_results[-1] + +# # Check combining 3 group objects +# nfg4 = combine_model_objs([tfg, tfg2, tfg3]) +# assert nfg4 +# assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3) +# assert compare_model_objs([nfg4, tfg, tfg2, tfg3], 'settings') +# assert nfg4.group_results[0] == tfg.group_results[0] +# assert nfg4.group_results[-1] == tfg3.group_results[-1] + +# # Check combining a mixture of model & group objects +# nfg5 = combine_model_objs([tfg, tfm, tfg2, tfm2]) +# assert nfg5 +# assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1 +# assert compare_model_objs([nfg5, tfg, tfm, tfg2, tfm2], 'settings') +# assert nfg5.group_results[0] == tfg.group_results[0] +# assert nfg5.group_results[-1] == tfm2.get_results() + +# # Check combining objects with no data +# tfm2._reset_data_results(False, True, True) +# tfg2._reset_data_results(False, True, True, True) +# nfg6 = combine_model_objs([tfm2, tfg2]) +# assert len(nfg6) == 1 + len(tfg2) +# assert nfg6.power_spectra is None + +# def test_combine_errors(tfm, tfg): + +# # Incompatible settings +# for f_obj in [tfm, tfg]: +# f_obj2 = f_obj.copy() +# f_obj2.peak_width_limits = [2, 4] +# f_obj2._reset_internal_settings() + +# with raises(IncompatibleSettingsError): +# combine_model_objs([f_obj, f_obj2]) + +# # Incompatible data information +# for f_obj in [tfm, tfg]: +# f_obj2 = f_obj.copy() +# f_obj2.freq_range = [5, 30] + +# with raises(IncompatibleSettingsError): +# combine_model_objs([f_obj, f_obj2]) + +# def test_fit_models_3d(tfg): + +# n_groups = 2 +# n_spectra = 3 +# xs, ys = sim_group_power_spectra(n_spectra, *default_group_params()) +# ys = np.stack([ys] * n_groups, axis=0) +# spectra_shape = np.shape(ys) + +# tfg = SpectralGroupModel() +# fgs = fit_models_3d(tfg, xs, ys) + +# assert len(fgs) == n_groups == spectra_shape[0] +# for fg in fgs: +# assert fg +# assert len(fg) == n_spectra +# assert fg.power_spectra.shape == spectra_shape[1:] diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index fc602e0f..45d80d39 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -1,46 +1,46 @@ -"""Test functions for specparam.utils.io.""" +# """Test functions for specparam.utils.io.""" -import numpy as np +# import numpy as np -from specparam.core.items import OBJ_DESC -from specparam.objs import SpectralModel, SpectralGroupModel +# from specparam.core.items import OBJ_DESC +# from specparam.objs import SpectralModel, SpectralGroupModel -from specparam.tests.settings import TEST_DATA_PATH +# from specparam.tests.settings import TEST_DATA_PATH -from specparam.utils.io import * +# from specparam.utils.io import * -################################################################################################### -################################################################################################### +# ################################################################################################### +# ################################################################################################### -def test_load_model(): +# def test_load_model(): - file_name = 'test_all' +# file_name = 'test_all' - tfm = load_model(file_name, TEST_DATA_PATH) +# tfm = load_model(file_name, TEST_DATA_PATH) - assert isinstance(tfm, SpectralModel) +# assert isinstance(tfm, SpectralModel) - # Check that all elements get loaded - for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None - for data in OBJ_DESC['data']: - assert getattr(tfm, data) is not None - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) is not None +# # Check that all elements get loaded +# for result in OBJ_DESC['results']: +# assert not np.all(np.isnan(getattr(tfm, result))) +# for setting in OBJ_DESC['settings']: +# assert getattr(tfm, setting) is not None +# for data in OBJ_DESC['data']: +# assert getattr(tfm, data) is not None +# for meta_dat in OBJ_DESC['meta_data']: +# assert getattr(tfm, meta_dat) is not None -def test_load_group(): +# def test_load_group(): - file_name = 'test_group_all' - tfg = load_group(file_name, TEST_DATA_PATH) +# file_name = 'test_group_all' +# tfg = load_group(file_name, TEST_DATA_PATH) - assert isinstance(tfg, SpectralGroupModel) +# assert isinstance(tfg, SpectralGroupModel) - # Check that all elements get loaded - assert len(tfg.group_results) > 0 - for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None - assert tfg.power_spectra is not None - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfg, meta_dat) is not None +# # Check that all elements get loaded +# assert len(tfg.group_results) > 0 +# for setting in OBJ_DESC['settings']: +# assert getattr(tfg, setting) is not None +# assert tfg.power_spectra is not None +# for meta_dat in OBJ_DESC['meta_data']: +# assert getattr(tfg, meta_dat) is not None From f4edc40f8e369109da4e164f755ad52b3ca75202 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 9 Aug 2023 22:35:16 -0400 Subject: [PATCH 03/20] acess param inds from Mode object --- specparam/objs/model.py | 13 ++++++++----- specparam/objs/modes.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/specparam/objs/model.py b/specparam/objs/model.py index d3651b22..53e094c7 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -10,7 +10,7 @@ from specparam.objs.base import BaseObject from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.core.info import get_indices +#from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method @@ -214,10 +214,6 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - # Allow for shortcut alias, without adding `_params` if name in ['aperiodic', 'peak', 'gaussian']: name = name + '_params' @@ -229,6 +225,13 @@ def get_params(self, name, col=None): if isinstance(out, np.ndarray) and out.size == 0: out = np.array([np.nan, np.nan, np.nan]) + # If col specified as string, get mapping back to integer + if isinstance(col, str): + if 'aperiodic' in name: + col = self.aperiodic_mode.param_indices[col.lower()] + else: + col = self.periodic_mode.param_indices[col.lower()] + # Select out a specific column, if requested if col is not None: diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py index 9d3ee201..033bd9e5 100644 --- a/specparam/objs/modes.py +++ b/specparam/objs/modes.py @@ -22,6 +22,10 @@ def __init__(self, name, component, description, func, params, param_description def __repr__(self): return self.name + @property + def param_indices(self): + return {label : index for index, label in enumerate(self.params)} + ################################################################################################### ## APERIODIC MODES From 8a015f8b6962ad32689780e16904d636e768f7a1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 9 Aug 2023 23:22:29 -0400 Subject: [PATCH 04/20] clean ups --- specparam/objs/algorithm.py | 8 +++--- specparam/objs/model.py | 1 - specparam/objs/modes.py | 51 +------------------------------------ 3 files changed, 5 insertions(+), 55 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index db69cc86..696310ce 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -232,7 +232,7 @@ def _reset_results(self, clear_results=False): if clear_results: self.aperiodic_params_ = np.array([np.nan] * \ - (2 if str(self.aperiodic_mode) == 'fixed' else 3)) + (2 if self.aperiodic_mode.name == 'fixed' else 3)) self.gaussian_params_ = np.empty([0, 3]) self.peak_params_ = np.empty([0, 3]) self.r_squared_ = np.nan @@ -273,13 +273,13 @@ def _simple_ap_fit(self, freqs, power_spectrum): # Get the guess parameters and/or calculate from the data, as needed # Note that these are collected as lists, to concatenate with or without knee later off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if str(self.aperiodic_mode) == 'knee' else [] + kne_guess = [self._ap_guess[1]] if self.aperiodic_mode.name == 'knee' else [] exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) if not self._ap_guess[2] else self._ap_guess[2]] # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if str(self.aperiodic_mode) == 'knee' \ + ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) # Collect together guess parameters @@ -343,7 +343,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): spectrum_ignore = power_spectrum[perc_mask] # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if str(self.aperiodic_mode) == 'knee' \ + ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ else tuple(bound[0::2] for bound in self._ap_bounds) # Second aperiodic fit - using results of first fit as guess parameters diff --git a/specparam/objs/model.py b/specparam/objs/model.py index 53e094c7..f6d8e20a 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -10,7 +10,6 @@ from specparam.objs.base import BaseObject from specparam.objs.algorithm import SpectralFitAlgorithm -#from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py index 033bd9e5..d148f2c9 100644 --- a/specparam/objs/modes.py +++ b/specparam/objs/modes.py @@ -1,4 +1,4 @@ -""" """ +"""Define fitting modes.""" from specparam.core.funcs import * @@ -29,21 +29,6 @@ def param_indices(self): ################################################################################################### ## APERIODIC MODES -# 'Fixed' model -# ap_fixed = { -# 'name' : 'fixed', -# 'component' : 'aperiodic', -# 'description' : 'Fit an exponential, with no knee.', -# 'func' : expo_nk_function, -# 'params' : ['offset', 'exponent'], -# 'param_description' : { -# 'offset' : 'Offset of the aperiodic component.', -# 'exponent' : 'Exponent of the aperiodic component.' -# }, -# 'freq_space' : 'linear', -# 'powers_space' : 'log10', -# } - param_desc_fixed = { 'offset' : 'Offset of the aperiodic component.', 'exponent' : 'Exponent of the aperiodic component.', @@ -52,24 +37,6 @@ def param_indices(self): expo_nk_function, ['offset', 'exponent'], param_desc_fixed, 'linear', 'log10') - -# 'Knee' model -# ap_knee = { -# 'name' : 'knee', -# 'component' : 'aperiodic', -# 'description' : 'Fit an exponential, with a knee.', -# 'func' : expo_function, -# 'params' : ['offset', 'knee', 'exponent'], -# 'param_description' : { -# 'offset' : 'Offset of the aperiodic component.', -# 'knee' : 'Knee of the aperiodic component.', -# 'exponent' : 'Exponent of the aperiodic component.' -# }, -# 'freq_space' : 'linear', -# 'powers_space' : 'log10', -# } - - param_desc_knee = { 'offset' : 'Offset of the aperiodic component.', 'knee' : 'Knee of the aperiodic component.', @@ -89,22 +56,6 @@ def param_indices(self): ################################################################################################### ## PERIODIC MODES -# # 'Gaussian' model -# pe_gaussian = { -# 'name' : 'gaussian', -# 'component' : 'periodic', -# 'description' : 'Gaussian peak fit function.', -# 'func' : gaussian_function, -# 'params' : ['cf', 'pw', 'bw'], -# 'param_description' : { -# 'cf' : 'Center frequency of the peak.', -# 'pw' : 'Power of the peak, over and above the aperiodic component.', -# 'bw' : 'Bandwidth of the peak.', -# }, -# 'freq_space' : 'linear', -# 'powers_space' : 'log10', -# } - param_desc_gaus = { 'cf' : 'Center frequency of the peak.', 'pw' : 'Power of the peak, over and above the aperiodic component.', From 5c2ce372d797075cd4623b2bf492565b7280edb7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 9 Aug 2023 23:57:29 -0400 Subject: [PATCH 05/20] update modes orgs & add n_params --- specparam/objs/modes.py | 43 ++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py index d148f2c9..4c2eda35 100644 --- a/specparam/objs/modes.py +++ b/specparam/objs/modes.py @@ -22,6 +22,10 @@ def __init__(self, name, component, description, func, params, param_description def __repr__(self): return self.name + @property + def n_params(self): + return len(self.params) + @property def param_indices(self): return {label : index for index, label in enumerate(self.params)} @@ -33,9 +37,16 @@ def param_indices(self): 'offset' : 'Offset of the aperiodic component.', 'exponent' : 'Exponent of the aperiodic component.', } -ap_fixed = Mode('fixed', 'aperiodic', 'Fit an exponential, with no knee.', - expo_nk_function, ['offset', 'exponent'], param_desc_fixed, - 'linear', 'log10') +ap_fixed = Mode( + name='fixed', + component='aperiodic', + description='Fit an exponential, with no knee.', + func=expo_nk_function, + params=['offset', 'exponent'], + param_description=param_desc_fixed, + freq_space='linear', + powers_space='log10', +) param_desc_knee = { 'offset' : 'Offset of the aperiodic component.', @@ -43,9 +54,16 @@ def param_indices(self): 'exponent' : 'Exponent of the aperiodic component.', } -ap_knee = Mode('knee', 'aperiodic', 'Fit an exponential, with a knee.', - expo_function, ['offset', 'knee', 'exponent'], param_desc_knee, - 'linear', 'log10') +ap_knee = Mode( + name='knee', + component='aperiodic', + description='Fit an exponential, with a knee.', + func=expo_function, + params=['offset', 'knee', 'exponent'], + param_description=param_desc_knee, + freq_space='linear', + powers_space='log10', +) # Collect available aperiodic modes AP_MODES = { @@ -62,9 +80,16 @@ def param_indices(self): 'bw' : 'Bandwidth of the peak.', } -pe_gaussian = Mode('gaussian', 'periodic', 'Gaussian peak fit function.', - gaussian_function, ['cf', 'pw', 'bw'], param_desc_gaus, - 'linear', 'log10') +pe_gaussian = Mode( + name='gaussian', + component='periodic', + description='Gaussian peak fit function.', + func=gaussian_function, + params=['cf', 'pw', 'bw'], + param_description=param_desc_gaus, + freq_space='linear', + powers_space='log10', +) # Collect available periodic modes PE_MODES = { From f4c35d2eb388b5cc24cd969a5649314ec64fca15 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 9 Aug 2023 23:58:02 -0400 Subject: [PATCH 06/20] add periodic_mode as a passable input --- specparam/objs/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/specparam/objs/model.py b/specparam/objs/model.py index f6d8e20a..d78176ef 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -95,10 +95,11 @@ class SpectralModel(SpectralFitAlgorithm, BaseObject): """ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): + peak_threshold=2.0, aperiodic_mode='fixed', periodic_mode='gaussian', + verbose=True, **model_kwargs): """Initialize model object.""" - BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, debug_mode=model_kwargs.pop('debug_mode', False), verbose=verbose) SpectralFitAlgorithm.__init__(self, peak_width_limits=peak_width_limits, From 88cdd0a934494dc85c94107197e2c799f1c7405b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 10 Aug 2023 00:21:46 -0400 Subject: [PATCH 07/20] group_three -> groupby --- specparam/core/utils.py | 16 +++++++++------- specparam/sim/params.py | 4 ++-- specparam/tests/core/test_utils.py | 6 +++--- specparam/tests/objs/test_model.py | 6 +++--- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 9a7ab2a8..84a96fde 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -27,13 +27,15 @@ def unlog(arr, base=10): return np.power(base, arr) -def group_three(vec): - """Group an array of values into threes. +def groupby(vec, groupby): + """Group an array of values by a specified number. Parameters ---------- vec : list or 1d array List or array of items to group by 3. Length of array must be divisible by three. + num : int + Number to group by. Returns ------- @@ -43,17 +45,17 @@ def group_three(vec): Raises ------ ValueError - If input data cannot be evenly grouped into threes. + If input data cannot be evenly grouped into specified number. """ - if len(vec) % 3 != 0: + if len(vec) % groupby != 0: raise ValueError("Wrong size array to group by three.") - # Reshape, if an array, as it's faster, otherwise asssume lise + # Reshape, if an array, as it's faster, otherwise assume list if isinstance(vec, np.ndarray): - return np.reshape(vec, (-1, 3)) + return np.reshape(vec, (-1, groupby)) else: - return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)] + return [list(vec[ii:ii+groupby]) for ii in range(0, len(vec), groupby)] def nearest_ind(array, value): diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 5fcfde1f..132edeb1 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -2,7 +2,7 @@ import numpy as np -from specparam.core.utils import group_three, check_flat +from specparam.core.utils import groupby, check_flat from specparam.core.info import get_indices from specparam.core.funcs import infer_ap_func from specparam.core.errors import InconsistentDataError @@ -31,7 +31,7 @@ def collect_sim_params(aperiodic_params, periodic_params, nlv): """ return SimParams(aperiodic_params.copy(), - sorted(group_three(check_flat(periodic_params))), + sorted(groupby(check_flat(periodic_params), 3)), nlv) diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index ab693456..14398617 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -19,13 +19,13 @@ def test_unlog(): unlogged = unlog(logged) assert np.array_equal(orig, unlogged) -def test_group_three(): +def test_groupby(): dat = [0, 1, 2, 3, 4, 5] - assert group_three(dat) == [[0, 1, 2], [3, 4, 5]] + assert groupby(dat, 3) == [[0, 1, 2], [3, 4, 5]] with raises(ValueError): - group_three([0, 1, 2, 3]) + groupby([0, 1, 2, 3], 3) def test_dict_array_to_lst(): diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py index 24f986fd..56fa18ba 100644 --- a/specparam/tests/objs/test_model.py +++ b/specparam/tests/objs/test_model.py @@ -11,7 +11,7 @@ from specparam.core.items import OBJ_DESC from specparam.core.errors import FitError -from specparam.core.utils import group_three +from specparam.core.utils import groupby from specparam.sim import gen_freqs, sim_power_spectrum from specparam.data import FitResults from specparam.core.modutils import safe_import @@ -69,7 +69,7 @@ def test_fit_nk(): assert np.allclose(ap_params, tfm.aperiodic_params_, [0.5, 0.1]) # Check model results - gaussian parameters - for ii, gauss in enumerate(group_three(gauss_params)): + for ii, gauss in enumerate(groupby(gauss_params, 3)): assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): @@ -100,7 +100,7 @@ def test_fit_knee(): assert np.allclose(ap_params, tfm.aperiodic_params_, [1, 2, 0.2]) # Check model results - gaussian parameters - for ii, gauss in enumerate(group_three(gauss_params)): + for ii, gauss in enumerate(groupby(gauss_params, 3)): assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) def test_fit_measures(): From 42af5bc28c7e60f6f6cbb81087d273c8fa3ef503 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 10 Aug 2023 00:30:40 -0400 Subject: [PATCH 08/20] generalize processes for number of params (including temp updates) --- specparam/objs/algorithm.py | 60 +++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 696310ce..a1c805a3 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -6,12 +6,10 @@ from numpy.linalg import LinAlgError from scipy.optimize import curve_fit -from specparam.core.utils import group_three +from specparam.core.utils import groupby from specparam.core.strings import gen_width_warning_str -#from specparam.core.funcs import gaussian_function, get_ap_func from specparam.core.errors import NoDataError, FitError from specparam.utils.params import compute_gauss_std -from specparam.sim.gen import gen_aperiodic, gen_periodic ################################################################################################### ################################################################################################### @@ -149,7 +147,6 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Fit the aperiodic component self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - #self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) self._ap_fit = self.aperiodic_mode.func(self.freqs, *self.aperiodic_params_) # Flatten the power spectrum using fit aperiodic fit @@ -160,8 +157,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - #self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) - self._peak_fit = self.periodic_mode.func(self.freqs, *np.ndarray.flatten(self.gaussian_params_)) + self._peak_fit = self.periodic_mode.func(\ + self.freqs, *np.ndarray.flatten(self.gaussian_params_)) # Create peak-removed (but not flattened) power spectrum self._spectrum_peak_rm = self.power_spectrum - self._peak_fit @@ -169,7 +166,6 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Run final aperiodic fit on peak-removed power spectrum # This overwrites previous aperiodic fit, and recomputes the flattened spectrum self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - #self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) self._ap_fit = self.aperiodic_mode.func(self.freqs, *self.aperiodic_params_) self._spectrum_flat = self.power_spectrum - self._ap_fit @@ -231,10 +227,9 @@ def _reset_results(self, clear_results=False): if clear_results: - self.aperiodic_params_ = np.array([np.nan] * \ - (2 if self.aperiodic_mode.name == 'fixed' else 3)) - self.gaussian_params_ = np.empty([0, 3]) - self.peak_params_ = np.empty([0, 3]) + self.aperiodic_params_ = np.array([np.nan] * self.aperiodic_mode.n_params) + self.gaussian_params_ = np.empty([0, self.periodic_mode.n_params]) + self.peak_params_ = np.empty([0, self.periodic_mode.n_params]) self.r_squared_ = np.nan self.error_ = np.nan @@ -285,6 +280,11 @@ def _simple_ap_fit(self, freqs, power_spectrum): # Collect together guess parameters guess = np.array(off_guess + kne_guess + exp_guess) + ## TEMP + if self.aperiodic_mode.name == 'doublexp': + ap_bounds = self._ap_bounds + guess = self._ap_guess + # Ignore warnings that are raised in curve_fit # A runtime warning can occur while exploring parameters in curve fitting # This doesn't effect outcome - it won't settle on an answer that does this @@ -293,7 +293,6 @@ def _simple_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, - #get_ap_func(self.aperiodic_mode), freqs, power_spectrum, p0=guess, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError as excp: @@ -327,7 +326,6 @@ def _robust_ap_fit(self, freqs, power_spectrum): # Do a quick, initial aperiodic fit popt = self._simple_ap_fit(freqs, power_spectrum) - #initial_fit = gen_aperiodic(freqs, popt) initial_fit = self.aperiodic_mode.func(freqs, *popt) # Flatten power_spectrum based on initial aperiodic fit @@ -352,7 +350,6 @@ def _robust_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, - #get_ap_func(self.aperiodic_mode), freqs_ignore, spectrum_ignore, p0=popt, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError as excp: @@ -383,7 +380,7 @@ def _fit_peaks(self, flat_iter): """ # Initialize matrix of guess parameters for gaussian fitting - guess = np.empty([0, 3]) + guess = np.empty([0, self.periodic_mode.n_params]) # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds @@ -439,9 +436,15 @@ def _fit_peaks(self, flat_iter): guess_std = self._gauss_std_limits[1] # Collect guess parameters and subtract this guess gaussian from the data - guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = self.periodic_mode.func(self.freqs, guess_freq, guess_height, guess_std) - #peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) + current_guess_params = (guess_freq, guess_height, guess_std) + + ## TEMP + if self.periodic_mode.name == 'skewnorm': + guess_skew = 0 + current_guess_params = (guess_freq, guess_height, guess_std, guess_skew) + + guess = np.vstack((guess, current_guess_params)) + peak_gauss = self.periodic_mode.func(self.freqs, *current_guess_params) flat_iter = flat_iter - peak_gauss # Check peaks based on edges, and on overlap, dropping any that violate requirements @@ -453,7 +456,7 @@ def _fit_peaks(self, flat_iter): gaussian_params = self._fit_peak_guess(guess) gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] else: - gaussian_params = np.empty([0, 3]) + gaussian_params = np.empty([0, self.periodic_mode.n_params]) return gaussian_params @@ -501,9 +504,10 @@ def _fit_peak_guess(self, guess): # Fit the peaks try: gaussian_params, _ = curve_fit(self.periodic_mode.func, - #gaussian_function, self.freqs, self._spectrum_flat, - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) + p0=guess, maxfev=self._maxfev, + #bounds=gaus_param_bounds ##TEMP + ) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the peak component fit.") @@ -515,7 +519,7 @@ def _fit_peak_guess(self, guess): raise FitError(error_msg) from excp # Re-organize params into 2d matrix - gaussian_params = np.array(group_three(gaussian_params)) + gaussian_params = np.array(groupby(gaussian_params, self.periodic_mode.n_params)) return gaussian_params @@ -624,7 +628,7 @@ def _create_peak_params(self, gaus_params): with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. """ - peak_params = np.empty((len(gaus_params), 3)) + peak_params = np.empty((len(gaus_params), self.periodic_mode.n_params)) for ii, peak in enumerate(gaus_params): @@ -632,7 +636,13 @@ def _create_peak_params(self, gaus_params): ind = np.argmin(np.abs(self.freqs - peak[0])) # Collect peak parameter data - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], - peak[2] * 2] + if self.periodic_mode.name == 'gaussian': ## TEMP + peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], + peak[2] * 2] + + ## TEMP: + if self.periodic_mode.name == 'skewnorm': + peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], + peak[2] * 2, peak[3]] return peak_params From 280b38290787f66ce560a3757ed848c54595b097 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 21:46:52 -0400 Subject: [PATCH 09/20] add labels to funcs file --- specparam/core/funcs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/specparam/core/funcs.py b/specparam/core/funcs.py index bc880919..4d0198cf 100644 --- a/specparam/core/funcs.py +++ b/specparam/core/funcs.py @@ -1,11 +1,4 @@ -"""Functions that can be used for model fitting. - -NOTES ------ -- Model fitting currently (only) uses the exponential and gaussian functions. -- Linear & Quadratic functions are from previous versions. - - They are left available for easy swapping back in, if desired. -""" +"""Functions that can be used for model fitting.""" import numpy as np @@ -14,6 +7,8 @@ ################################################################################################### ################################################################################################### +## PEAK FUNCTIONS + def gaussian_function(xs, *params): """Gaussian fitting function. @@ -41,6 +36,8 @@ def gaussian_function(xs, *params): return ys +## APERIODIC FUNCTIONS + def expo_function(xs, *params): """Exponential fitting function, for fitting aperiodic component with a 'knee'. @@ -147,6 +144,8 @@ def quadratic_function(xs, *params): return ys +## GETTER FUNCTIONS + def get_pe_func(periodic_mode): """Select and return specified function for periodic component. From 8c9b5b14c8ef769d5281cfb1dc516b1930a462eb Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 22:04:16 -0400 Subject: [PATCH 10/20] add double exp function --- specparam/core/funcs.py | 29 ++++++++++++++++++++++++++++- specparam/tests/core/test_funcs.py | 28 +++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/specparam/core/funcs.py b/specparam/core/funcs.py index 4d0198cf..6beb91d8 100644 --- a/specparam/core/funcs.py +++ b/specparam/core/funcs.py @@ -35,7 +35,6 @@ def gaussian_function(xs, *params): return ys - ## APERIODIC FUNCTIONS def expo_function(xs, *params): @@ -94,6 +93,34 @@ def expo_nk_function(xs, *params): return ys +def double_expo_function(xs, *params): + """Double exponential fitting function, for fitting aperiodic component with two exponents and a knee. + + NOTE: this function requires linear frequency (not log). + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters (offset, exp0, knee, exp1) that define the function: + y = 10^offset * (1/((x**exp0) * (knee + x^exp1)) + + Returns + ------- + ys : 1d array + Output values for exponential function. + """ + + ys = np.zeros_like(xs) + + offset, exp0, knee, exp1 = params + + ys = ys + offset - np.log10((xs**exp0) * (knee + xs**exp1)) + + return ys + + def linear_function(xs, *params): """Linear fitting function. diff --git a/specparam/tests/core/test_funcs.py b/specparam/tests/core/test_funcs.py index 5332676b..97bba4d2 100644 --- a/specparam/tests/core/test_funcs.py +++ b/specparam/tests/core/test_funcs.py @@ -12,6 +12,8 @@ ################################################################################################### ################################################################################################### +## Periodic functions + def test_gaussian_function(): ctr, hgt, wid = 50, 5, 10 @@ -26,6 +28,8 @@ def test_gaussian_function(): assert max(ys) == hgt assert np.allclose([i/sum(ys) for i in ys], norm.pdf(xs, ctr, wid)) +## Aperiodic functions + def test_expo_function(): off, knee, exp = 10, 5, 2 @@ -53,10 +57,26 @@ def test_expo_nk_function(): # By design, this expo function assumes linear xs and log-space ys # Where the log-log should be a straight line. Use that to test. - sl_meas, off_meas, _, _, _ = linregress(np.log10(xs), ys) - + exp_meas, off_meas, _, _, _ = linregress(np.log10(xs), ys) assert np.isclose(off, off_meas) - assert np.isclose(exp, np.abs(sl_meas)) + assert np.isclose(exp, np.abs(exp_meas)) + +def test_double_expo_function(): + + off, exp0, knee, exp1 = 10, 1, 5, 1 + + xs = np.arange(0.1, 100, 0.1) + ys = double_expo_function(xs, off, exp0, knee, exp1) + + assert np.all(ys) + + # Note: no obvious way to test the knee specifically + # Here - test that exponents at edges of the psd (pre & post knee) are as expected + exp_meas0, off_meas0, _, _, _ = linregress(np.log10(xs[:5]), ys[:5]) + assert np.isclose(np.abs(exp_meas0), exp0, 0.1) + exp_meas1, off_meas1, _, _, _ = linregress(np.log10(xs[-5:]), ys[-5:]) + assert np.isclose(np.abs(exp_meas1), exp0 + exp1, 0.1) + assert np.isclose(off_meas1, off, 0.25) def test_linear_function(): @@ -87,6 +107,8 @@ def test_quadratic_function(): assert np.isclose(sl_meas, sl) assert np.isclose(curve_meas, curve) +## Getter functions + def test_get_pe_func(): pe_ga_func = get_pe_func('gaussian') From 7f2bd95c567d513593bda897f98fea408fe62bff Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 22:40:03 -0400 Subject: [PATCH 11/20] add normalize util func --- specparam/core/utils.py | 17 +++++++++++++++++ specparam/tests/core/test_utils.py | 11 +++++++++++ 2 files changed, 28 insertions(+) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 84a96fde..61f221cc 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -27,6 +27,23 @@ def unlog(arr, base=10): return np.power(base, arr) +def normalize(data): + """Normalize an array of numerical data (to the range of 0-1). + + Parameters + ---------- + data : np.ndarray + Array of data to normalize. + + Returns + ------- + np.ndarray + Normalized data. + """ + + return (data - np.min(data)) / (np.max(data) - np.min(data)) + + def groupby(vec, groupby): """Group an array of values by a specified number. diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index 14398617..4947e439 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -19,6 +19,17 @@ def test_unlog(): unlogged = unlog(logged) assert np.array_equal(orig, unlogged) + +def test_normalize(): + + arr1 = np.array([0, 0.25, 0.5]) + norm_arr1 = normalize(arr1) + assert np.array_equal(norm_arr1, np.array([0.0, 0.5, 1.0])) + + arr2 = np.array([0, 5, 10]) + norm_arr2 = normalize(arr2) + assert np.array_equal(norm_arr2, np.array([0.0, 0.5, 1.0])) + def test_groupby(): dat = [0, 1, 2, 3, 4, 5] From 8eb27adfc40e9947ebb8f924dd3b4a3e0937e4b1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 22:40:26 -0400 Subject: [PATCH 12/20] add skewed gaussian fit function --- specparam/core/funcs.py | 32 ++++++++++++++++++++++++++++++ specparam/tests/core/test_funcs.py | 21 +++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/specparam/core/funcs.py b/specparam/core/funcs.py index 6beb91d8..bdb8a04e 100644 --- a/specparam/core/funcs.py +++ b/specparam/core/funcs.py @@ -1,7 +1,9 @@ """Functions that can be used for model fitting.""" import numpy as np +from scipy.special import erf +from specparam.core.utils import normalize from specparam.core.errors import InconsistentDataError ################################################################################################### @@ -35,6 +37,36 @@ def gaussian_function(xs, *params): return ys + +def skewnorm_function(xs, *params): + """Skewed normal distribution fitting function. + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters that define the skewed normal distribution function. + + Returns + ------- + ys : 1d array + Output values for skewed normal distribution function. + """ + + ys = np.zeros_like(xs) + + for ii in range(0, len(params), 4): + + ctr, hgt, wid, skew = params[ii:ii+4] + + ts = (xs - ctr) / wid + temp = 2 / wid * (1 / np.sqrt(2 * np.pi) * np.exp(-ts**2 / 2)) * \ + ((1 + erf(skew * ts / np.sqrt(2))) / 2) + ys = ys + hgt * normalize(temp) + + return ys + ## APERIODIC FUNCTIONS def expo_function(xs, *params): diff --git a/specparam/tests/core/test_funcs.py b/specparam/tests/core/test_funcs.py index 97bba4d2..c71eeaee 100644 --- a/specparam/tests/core/test_funcs.py +++ b/specparam/tests/core/test_funcs.py @@ -26,7 +26,26 @@ def test_gaussian_function(): # Check distribution matches generated gaussian from scipy # Generated gaussian is normalized for this comparison, height tested separately assert max(ys) == hgt - assert np.allclose([i/sum(ys) for i in ys], norm.pdf(xs, ctr, wid)) + assert np.allclose([ii/sum(ys) for ii in ys], norm.pdf(xs, ctr, wid)) + +def test_skewnorm_function(): + + # Check that with no skew, approximate gaussian + ctr, hgt, wid, skew = 50, 5, 10, 1 + xs = np.arange(1, 100) + ys_gaus = gaussian_function(xs, ctr, hgt, wid) + ys_skew = skewnorm_function(xs, ctr, hgt, wid, skew) + np.allclose(ys_gaus, ys_skew, atol=0.001) + + # Check with some skew - right skew (more density after center) + skew1 = 2 + ys_skew1 = skewnorm_function(xs, ctr, hgt, wid, skew1) + assert sum(ys_skew1[xsctr]) + + # Check with some skew - left skew (more density before center) + skew2 = -2 + ys_skew2 = skewnorm_function(xs, ctr, hgt, wid, skew2) + assert sum(ys_skew2[xs sum(ys_skew2[xs>ctr]) ## Aperiodic functions From 5f0cca3c5995158fcafb653c7e58904036cd2f7f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 22:53:56 -0400 Subject: [PATCH 13/20] add new modes --- specparam/objs/modes.py | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py index 4c2eda35..9f3befbf 100644 --- a/specparam/objs/modes.py +++ b/specparam/objs/modes.py @@ -19,13 +19,16 @@ def __init__(self, name, component, description, func, params, param_description self.freq_space = freq_space self.powers_space = powers_space + def __repr__(self): return self.name + @property def n_params(self): return len(self.params) + @property def param_indices(self): return {label : index for index, label in enumerate(self.params)} @@ -33,6 +36,7 @@ def param_indices(self): ################################################################################################### ## APERIODIC MODES +# Fixed param_desc_fixed = { 'offset' : 'Offset of the aperiodic component.', 'exponent' : 'Exponent of the aperiodic component.', @@ -48,6 +52,7 @@ def param_indices(self): powers_space='log10', ) +# Knee param_desc_knee = { 'offset' : 'Offset of the aperiodic component.', 'knee' : 'Knee of the aperiodic component.', @@ -65,15 +70,37 @@ def param_indices(self): powers_space='log10', ) + +# Double exponent +param_desc = { + 'offset' : 'Offset of the aperiodic component.', + 'exponent0' : 'Exponent of the aperiodic component, before the knee.', + 'knee' : 'Knee of the aperiodic component.', + 'exponent1' : 'Exponent of the aperiodic component, after the knee.', + } + +ap_doublexp = Mode( + name='doublexp', + component='aperiodic', + description='Fit an function with 2 exponents and a knee.', + func=double_expo_function, + params=['offset', 'exponent0', 'knee', 'exponent1'], + param_description=param_desc, + freq_space='linear', + powers_space='log10', +) + # Collect available aperiodic modes AP_MODES = { 'fixed' : ap_fixed, 'knee' : ap_knee, + 'doublexp' : ap_doublexp, } ################################################################################################### ## PERIODIC MODES +# Gaussian param_desc_gaus = { 'cf' : 'Center frequency of the peak.', 'pw' : 'Power of the peak, over and above the aperiodic component.', @@ -91,7 +118,27 @@ def param_indices(self): powers_space='log10', ) +# Skewed Gaussian +param_desc_skew = { + 'cf' : 'Center frequency of the peak.', + 'pw' : 'Power of the peak, over and above the aperiodic component.', + 'bw' : 'Bandwidth of the peak.', + 'skew' : 'Skewness of the peak.', + } + +pe_skewnorm = Mode( + name='skewnorm', + component='periodic', + description='Skewed Gaussian peak fit function.', + func=skewnorm_function, + params=['cf', 'pw', 'bw', 'skew'], + param_description=param_desc_skew, + freq_space='linear', + powers_space='log10', +) + # Collect available periodic modes PE_MODES = { 'gaussian' : pe_gaussian, + 'skewed_gaussian' : pe_skewnorm, } From d046a3c4d3afe266afe0bb1655dddda7dd8f9bd2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 13 Aug 2023 22:59:00 -0400 Subject: [PATCH 14/20] add docs to Mode object --- specparam/objs/modes.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/specparam/objs/modes.py b/specparam/objs/modes.py index 9f3befbf..78d847db 100644 --- a/specparam/objs/modes.py +++ b/specparam/objs/modes.py @@ -6,9 +6,32 @@ ## MODES OUTLINE class Mode(): + """Defines a fit mode. + + Parameters + ---------- + name : str + Name of the mode. + component : {'periodic', 'aperiodic'}, + Which component the mode relates to. + description : str + Description of the mode. + func : callable + Function that defines the fit function for the mode. + params : list of str + Name of output parameter(s). + param_description : dict + Descriptions of the parameters. + Should have same length and keys as `params`. + freq_space : {'linear', 'log10'} + Required spacing of the frequency values for this mode. + powers_space : {'linear', 'log10'} + Required spacing of the power values for this mode. + """ def __init__(self, name, component, description, func, params, param_description, freq_space, powers_space): + """Initialize a mode.""" self.name = name self.component = component @@ -21,16 +44,22 @@ def __init__(self, name, component, description, func, params, param_description def __repr__(self): + """Return representation of this object as the name.""" + return self.name @property def n_params(self): + """Define property attribute for the number of parameters.""" + return len(self.params) @property def param_indices(self): + """Define property attribute for the indices of the parameters.""" + return {label : index for index, label in enumerate(self.params)} ################################################################################################### From bf3944dfa27a3e95a80b21a4ddee8d7fea5f11e5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 20 Aug 2023 13:11:23 -0400 Subject: [PATCH 15/20] start process new bounds / guess funcs --- specparam/objs/algorithm.py | 97 ++++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 35 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index a1c805a3..739f1e5e 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -249,6 +249,40 @@ def _check_width_limits(self): print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) + def _get_ap_guess(self, freqs, power_spectrum): + """ """ + + # Get the guess parameters and/or calculate from the data, as needed + # Note that these are collected as lists, to concatenate with or without knee later + off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] + kne_guess = [self._ap_guess[1]] if self.aperiodic_mode.name == 'knee' else [] + exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / + (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) + if not self._ap_guess[2] else self._ap_guess[2]] + + # Collect together guess parameters + ap_guess = np.array(off_guess + kne_guess + exp_guess) + + ## TEMP + if self.aperiodic_mode.name == 'doublexp': + ap_guess = self._ap_guess + + return ap_guess + + + def _get_ap_bounds(self): + """ """ + + # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee + ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ + else tuple(bound[0::2] for bound in self._ap_bounds) + + if self.aperiodic_mode.name == 'doublexp': + ap_bounds = self._ap_bounds + + return ap_bounds + + def _simple_ap_fit(self, freqs, power_spectrum): """Fit the aperiodic component of the power spectrum. @@ -265,25 +299,9 @@ def _simple_ap_fit(self, freqs, power_spectrum): Parameter estimates for aperiodic fit. """ - # Get the guess parameters and/or calculate from the data, as needed - # Note that these are collected as lists, to concatenate with or without knee later - off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if self.aperiodic_mode.name == 'knee' else [] - exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / - (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) - if not self._ap_guess[2] else self._ap_guess[2]] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Collect together guess parameters - guess = np.array(off_guess + kne_guess + exp_guess) - - ## TEMP - if self.aperiodic_mode.name == 'doublexp': - ap_bounds = self._ap_bounds - guess = self._ap_guess + # Get the guess and bounds for the aperiodic parameters + ap_guess = self._get_ap_guess(freqs, power_spectrum) + ap_bounds = self._get_ap_bounds() # Ignore warnings that are raised in curve_fit # A runtime warning can occur while exploring parameters in curve fitting @@ -293,7 +311,7 @@ def _simple_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, - freqs, power_spectrum, p0=guess, + freqs, power_spectrum, p0=ap_guess, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " @@ -341,8 +359,9 @@ def _robust_ap_fit(self, freqs, power_spectrum): spectrum_ignore = power_spectrum[perc_mask] # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) + ap_bounds = self._get_ap_bounds() # TEMP + #ap_bounds = self._ap_bounds if self.aperiodic_mode.name == 'knee' \ + # else tuple(bound[0::2] for bound in self._ap_bounds) # Second aperiodic fit - using results of first fit as guess parameters # See note in _simple_ap_fit about warnings @@ -461,19 +480,8 @@ def _fit_peaks(self, flat_iter): return gaussian_params - def _fit_peak_guess(self, guess): - """Fits a group of peak guesses with a fit function. - - Parameters - ---------- - guess : 2d array, shape=[n_peaks, 3] - Guess parameters for gaussian fits to peaks, as gaussian parameters. - - Returns - ------- - gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters for gaussian fits to peaks, as gaussian parameters. - """ + def _get_pe_bounds(self, guess): + """ """ # Set the bounds for CF, enforce positive height value, and set bandwidth limits # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std @@ -498,6 +506,25 @@ def _fit_peak_guess(self, guess): gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), tuple(item for sublist in hi_bound for item in sublist)) + return gaus_param_bounds + + + def _fit_peak_guess(self, guess): + """Fits a group of peak guesses with a fit function. + + Parameters + ---------- + guess : 2d array, shape=[n_peaks, 3] + Guess parameters for gaussian fits to peaks, as gaussian parameters. + + Returns + ------- + gaussian_params : 2d array, shape=[n_peaks, 3] + Parameters for gaussian fits to peaks, as gaussian parameters. + """ + + gaus_param_bounds = self._get_pe_bounds(guess) + # Flatten guess, for use with curve fit guess = np.ndarray.flatten(guess) From b213edcb10d93754e1aa9ec25a095a935e7771c4 Mon Sep 17 00:00:00 2001 From: tom Date: Thu, 4 Apr 2024 08:44:34 -0400 Subject: [PATCH 16/20] Update readme badge links --- README.rst | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/README.rst b/README.rst index 03218b71..8a0f2918 100644 --- a/README.rst +++ b/README.rst @@ -2,29 +2,35 @@ Spectral Parameterization ========================= -|ProjectStatus|_ |Version|_ |BuildStatus|_ |Coverage|_ |License|_ |PythonVersions|_ |Paper|_ +|ProjectStatus| |Version| |BuildStatus| |Coverage| |License| |PythonVersions| |Publication| .. |ProjectStatus| image:: http://www.repostatus.org/badges/latest/active.svg -.. _ProjectStatus: https://www.repostatus.org/#active + :target: https://www.repostatus.org/#active + :alt: project status .. |Version| image:: https://img.shields.io/pypi/v/fooof.svg -.. _Version: https://pypi.python.org/pypi/fooof/ + :target: https://pypi.python.org/pypi/fooof/ + :alt: version .. |BuildStatus| image:: https://github.com/fooof-tools/fooof/actions/workflows/build.yml/badge.svg -.. _BuildStatus: https://github.com/fooof-tools/fooof/actions/workflows/build.yml + :target: https://github.com/fooof-tools/fooof/actions/workflows/build.yml + :alt: build status .. |Coverage| image:: https://codecov.io/gh/fooof-tools/fooof/branch/main/graph/badge.svg -.. _Coverage: https://codecov.io/gh/fooof-tools/fooof + :target: https://codecov.io/gh/fooof-tools/fooof + :alt: coverage .. |License| image:: https://img.shields.io/pypi/l/fooof.svg -.. _License: https://opensource.org/licenses/Apache-2.0 + :target: https://opensource.org/licenses/Apache-2.0 + :alt: license .. |PythonVersions| image:: https://img.shields.io/pypi/pyversions/fooof.svg -.. _PythonVersions: https://pypi.python.org/pypi/fooof/ - -.. |Paper| image:: https://img.shields.io/badge/paper-nn10.1038-informational.svg -.. _Paper: https://doi.org/10.1038/s41593-020-00744-x + :target: https://pypi.python.org/pypi/fooof/ + :alt: python versions +.. |Publication| image:: https://img.shields.io/badge/paper-nn10.1038-informational.svg + :target: https://doi.org/10.1038/s41593-020-00744-x + :alt: publication Spectral parameterization (`specparam`, formerly `fooof`) is a fast, efficient, and physiologically-informed tool to parameterize neural power spectra. From 024fb4a4d64fd04d75e002f7c2c761e95518f863 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 3 Aug 2024 16:32:44 -0400 Subject: [PATCH 17/20] add setable alpha level to add_shades --- specparam/plts/utils.py | 15 +++++++++++---- specparam/tests/plts/test_utils.py | 4 ++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/specparam/plts/utils.py b/specparam/plts/utils.py index f9156282..70fd1372 100644 --- a/specparam/plts/utils.py +++ b/specparam/plts/utils.py @@ -65,7 +65,8 @@ def set_alpha(n_points): return alpha -def add_shades(ax, shades, colors='r', add_center=False, logged=False): +def add_shades(ax, shades, colors='r', alpha=0.2, + add_center=False, center_alpha=0.6, logged=False): """Add shaded regions to a plot. Parameters @@ -76,8 +77,13 @@ def add_shades(ax, shades, colors='r', add_center=False, logged=False): Shaded region(s) to add to plot, defined as [lower_bound, upper_bound]. colors : str or list of string Color(s) to plot shades. + alpha : float or list of float, optional, default: 0.2 + The alpha level to add the shade regions with. + If a list, can specify a separate alpha level per shade. add_center : boolean, default: False Whether to add a line at the center point of the shaded regions. + center_alpha : float, optional, default: 0.6 + The alpha level for the center line, if added. logged : boolean, default: False Whether the shade values should be logged before applying to plot axes. """ @@ -87,16 +93,17 @@ def add_shades(ax, shades, colors='r', add_center=False, logged=False): shades = [shades] colors = repeat(colors) if not isinstance(colors, list) else colors + alphas = repeat(alpha) if not isinstance(alpha, list) else alpha - for shade, color in zip(shades, colors): + for shade, color, alpha in zip(shades, colors, alphas): shade = np.log10(shade) if logged else shade - ax.axvspan(shade[0], shade[1], color=color, alpha=0.2, lw=0) + ax.axvspan(shade[0], shade[1], color=color, alpha=alpha, lw=0) if add_center: center = sum(shade) / 2 - ax.axvspan(center, center, color='k', alpha=0.6) + ax.axvspan(center, center, color='k', alpha=center_alpha) def recursive_plot(data, plot_function, ax, **kwargs): diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index edfe80d1..a88accfa 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -33,6 +33,10 @@ def test_add_shades(skip_if_no_mpl): add_shades(check_ax(None), [4, 8]) +@plot_test +def test_add_shades_multi(skip_if_no_mpl): + add_shades(check_ax(None), [[4, 8], [8, 12], [12, 25]], colors=['b', 'c', 'y'], alpha=0.3) + @plot_test def test_recursive_plot(skip_if_no_mpl): From bc77ac54c42a8d15b5f0e00867ed938349f8e565 Mon Sep 17 00:00:00 2001 From: Kazuki Oikawa Date: Sat, 21 Sep 2024 01:47:03 +0900 Subject: [PATCH 18/20] Fixed matplotlib is required --- specparam/plts/settings.py | 6 ++++-- specparam/plts/style.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 263a5bee..b231e6da 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -2,13 +2,15 @@ from collections import OrderedDict -import matplotlib.pyplot as plt +from specparam.core.modutils import safe_import + +plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### # Define list of default plot colors -DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] +DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] if plt else None # Define default figure sizes PLT_FIGSIZES = {'spectral' : (8.5, 6.5), diff --git a/specparam/plts/style.py b/specparam/plts/style.py index 05bff602..8b5a71a5 100644 --- a/specparam/plts/style.py +++ b/specparam/plts/style.py @@ -3,12 +3,13 @@ from itertools import cycle from functools import wraps -import matplotlib.pyplot as plt - +from specparam.core.modutils import safe_import from specparam.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS, TICK_LABELSIZE, TITLE_FONTSIZE, LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC) +plt = safe_import('.pyplot', 'matplotlib') + ################################################################################################### ################################################################################################### From 8c2ff3e4cd6283bcca7051a9417df7a6dd2ef127 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 27 Nov 2024 21:05:32 -0500 Subject: [PATCH 19/20] add py13 tests --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 17f5f284..29ee3f20 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: MODULE_NAME: specparam strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 From cbbec6878767f070d2bd8fa4b2c71973c31a6e64 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 27 Nov 2024 21:10:03 -0500 Subject: [PATCH 20/20] list py3.13 support --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index fdf54cec..4d854707 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ], platforms = 'any', project_urls = {