Skip to content

Commit

Permalink
figured out weird warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 5, 2024
1 parent 6a47208 commit c6e4dea
Show file tree
Hide file tree
Showing 4 changed files with 5,306 additions and 382 deletions.
5,658 changes: 5,283 additions & 375 deletions examples/irasa_mne.ipynb

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# %% inherit from spectrum array

import matplotlib
import mne
import numpy as np
Expand All @@ -10,6 +11,8 @@
from pyrasa.utils.peak_utils import get_peak_params
from pyrasa.utils.types import SlopeFit

# FutureWarning:


class PeriodicSpectrumArray(SpectrumArray):
"""Subclass of SpectrumArray"""
Expand Down Expand Up @@ -170,7 +173,10 @@ def __init__(
)

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
self: SpectrumArray,
fit_func: str = 'fixed',
scale: bool = False,
fit_bounds: tuple[float, float] | None = None,
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
Expand Down Expand Up @@ -393,7 +399,10 @@ def __init__(
)

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
self: SpectrumArray,
fit_func: str = 'fixed',
scale: bool = False,
fit_bounds: tuple[float, float] | None = None,
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
Expand Down
11 changes: 8 additions & 3 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy.optimize import curve_fit


def get_args(f: Callable) -> list:
def _get_args(f: Callable) -> list:
return inspect.getfullargspec(f)[0][2:]


Expand Down Expand Up @@ -59,7 +59,12 @@ def func(self, x: np.ndarray, *args: float) -> np.ndarray:

@property
def curve_kwargs(self) -> dict[str, Any]:
return {}
return {
'maxfev': 10_000,
'ftol': 1e-5,
'xtol': 1e-5,
'gtol': 1e-5,
}

def add_infos_to_df(self, df_params: pd.DataFrame) -> pd.DataFrame:
return df_params
Expand All @@ -75,7 +80,7 @@ def fit_func(self) -> tuple[pd.DataFrame, pd.DataFrame]:
curve_kwargs = self.curve_kwargs
p, _ = curve_fit(self.func, self.freq, self.aperiodic_spectrum, **curve_kwargs)

my_args = get_args(self.func)
my_args = _get_args(self.func)
df_params = pd.DataFrame(dict(zip(my_args, p)), index=[0])
df_params['fit_type'] = self.label

Expand Down
6 changes: 4 additions & 2 deletions tests/test_mne.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from neurodsp.utils.sim import set_random_seed
from scipy.optimize import OptimizeWarning

from pyrasa.irasa_mne import irasa_epochs, irasa_raw

Expand All @@ -13,8 +14,9 @@ def test_mne(gen_mne_data_raw):

# test raw
irasa_raw_result = irasa_raw(mne_data, band=(0.25, 50), duration=2, hset_info=(1.0, 2.0, 0.05))
irasa_raw_result.aperiodic.get_slopes(fit_func='fixed')
irasa_raw_result.periodic.get_peaks(smoothing_window=2)
with pytest.warns(OptimizeWarning):
irasa_raw_result.aperiodic.get_slopes(fit_func='fixed')
irasa_raw_result.periodic.get_peaks(smoothing_window=2)

# test epochs
irasa_epoched_result = irasa_epochs(epochs, band=(0.5, 50), hset_info=(1.0, 2.0, 0.05))
Expand Down

0 comments on commit c6e4dea

Please sign in to comment.