Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds more flexibility in the aperiodic/oscillatory fitting #147

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions fooof/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False
self.print_results(False)


def fit(self, freqs=None, power_spectrum=None, freq_range=None):
def fit(self, freqs=None, power_spectrum=None, freq_range=None, ap_range=None):
"""Fit the full power spectrum as a combination of periodic and aperiodic components.

Parameters
Expand All @@ -332,15 +332,17 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
Power values, which must be input in linear space.
freq_range : list of [float, float], optional
Frequency range to restrict power spectrum to. If not provided, keeps the entire range.
ap_range : list of [float, float], or np.ndarray of booleans of the same length as freqs, optional.
Frequency range to restrict aperiodic fit to. If not provided, it will be fit on the range specified
by freq_range.

Notes
-----
Data is optional if data has been already been added to FOOOF object.
"""

# If freqs & power_spectrum provided together, add data to object.
if freqs is not None and power_spectrum is not None:
self.add_data(freqs, power_spectrum, freq_range)
self.add_data(freqs, power_spectrum, freq_range if ap_range is None else None)
# If power spectrum provided alone, add to object, and use existing frequency data
# Note: be careful passing in power_spectrum data like this:
# It assumes the power_spectrum is already logged, with correct freq_range.
Expand All @@ -358,27 +360,65 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
# In rare cases, the model fails to fit. Therefore it's in a try/except
# Cause of failure: RuntimeError, failure to find parameters in curve_fit
try:

if ap_range is not None:#isolate aperiodic frequencies/spectrum
if not isinstance(ap_range,np.ndarray) or ap_range.shape[-1]==2:
ap_inds = (self.freqs >= ap_range[0]) & (self.freqs <= ap_range[1])
elif ap_range.shape[-1]==self.freqs.shape[-1]:
ap_inds = ap_range
else:
raise ValueError('ap_range must have the same length as freqs - can not proceed')

ap_freqs = self.freqs[ap_inds]
ap_spectrum = self.power_spectrum[ap_inds]
else:
ap_freqs = self.freqs
ap_spectrum = self.power_spectrum

# Fit the aperiodic component
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
self.aperiodic_params_ = self._robust_ap_fit(ap_freqs, ap_spectrum)
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)

# Flatten the power_spectrum using fit aperiodic fit
self._spectrum_flat = self.power_spectrum - self._ap_fit


if ap_range is not None:#isolate periodic frequencies/spectrum
per_inds = (self.freqs >= freq_range[0]) & (self.freqs <= freq_range[1])
per_spectrum_flat = np.copy(self._spectrum_flat[per_inds])
self._spectrum_flat = per_spectrum_flat
#save/set some attributes so peak fitting works properly
freqs_0 = self.freqs
self.freqs = self.freqs[per_inds]
if freq_range:
freq_range_0 = self.freq_range
self.freq_range = freq_range



# Find peaks, and fit them with gaussians
self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat))

if ap_range is not None:
#restore attributes to initial values
self.freqs = freqs_0
self.freq_range = freq_range_0

# Calculate the peak fit
# Note: if no peaks are found, this creates a flat (all zero) peak fit.
self._peak_fit = gen_peaks(self.freqs, np.ndarray.flatten(self.gaussian_params_))

# Create peak-removed (but not flattened) power spectrum.
self._spectrum_peak_rm = self.power_spectrum - self._peak_fit

if ap_range is not None:
ap_spectrum_peak_rm = self._spectrum_peak_rm[ap_inds]
else:
ap_spectrum_peak_rm = self._spectrum_peak_rm

# Run final aperiodic fit on peak-removed power spectrum
# Note: This overwrites previous aperiodic fit
self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm)
self.aperiodic_params_ = self._simple_ap_fit(ap_freqs, ap_spectrum_peak_rm)
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)

# Create full power_spectrum model fit
Expand Down
4 changes: 2 additions & 2 deletions fooof/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def combine_fooofs(fooofs):
return fg


def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, ap_range=None, n_jobs=1):
"""Run FOOOFGroup across a 3D collection of power spectra.

Parameters
Expand All @@ -138,7 +138,7 @@ def fit_fooof_group_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):

fgs = []
for cond_spectra in power_spectra:
fg.fit(freqs, cond_spectra, freq_range, n_jobs)
fg.fit(freqs, cond_spectra, freq_range, ap_range, n_jobs)
fgs.append(fg.copy())

return fgs
12 changes: 6 additions & 6 deletions fooof/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):
self.print_results(False)


def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):
def fit(self, freqs=None, power_spectra=None, freq_range=None, ap_range=None, n_jobs=1):
"""Run FOOOF across a group of power_spectra.

Parameters
Expand All @@ -167,22 +167,22 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1):

# If freqs & power spectra provided together, add data to object.
if freqs is not None and power_spectra is not None:
self.add_data(freqs, power_spectra, freq_range)
self.add_data(freqs, power_spectra, freq_range if ap_range is None else None)

# Run linearly
if n_jobs == 1:
self._reset_group_results(len(self.power_spectra))
for ind, power_spectrum in \
_progress(enumerate(self.power_spectra), self.verbose, len(self)):
self._fit(power_spectrum=power_spectrum)
self._fit(power_spectrum=power_spectrum, freq_range=freq_range, ap_range=ap_range)
self.group_results[ind] = self._get_results()

# Run in parallel
else:
self._reset_group_results()
n_jobs = cpu_count() if n_jobs == -1 else n_jobs
with Pool(processes=n_jobs) as pool:
self.group_results = list(_progress(pool.imap(partial(_par_fit, fg=self),
self.group_results = list(_progress(pool.imap(partial(_par_fit, fg=self, freq_range=freq_range, ap_range=ap_range),
self.power_spectra),
self.verbose, len(self.power_spectra)))

Expand Down Expand Up @@ -366,10 +366,10 @@ def _check_width_limits(self):
###################################################################################################
###################################################################################################

def _par_fit(power_spectrum, fg):
def _par_fit(power_spectrum, fg, freq_range, ap_range):
"""Helper function for running in parallel."""

fg._fit(power_spectrum=power_spectrum)
fg._fit(power_spectrum=power_spectrum, freq_range=freq_range, ap_range=ap_range)

return fg._get_results()

Expand Down