diff --git a/pyrasa/utils/aperiodic_utils.py b/pyrasa/utils/aperiodic_utils.py index c1c89ac..fb47609 100644 --- a/pyrasa/utils/aperiodic_utils.py +++ b/pyrasa/utils/aperiodic_utils.py @@ -7,7 +7,7 @@ import pandas as pd from scipy.optimize import curve_fit -from pyrasa.utils.types import SlopeFit +from pyrasa.utils.types import SlopeFit, FitFun def fixed_model(x: np.ndarray, b0: float, b: float) -> np.ndarray: @@ -62,14 +62,16 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_func: str, semi_ def _compute_slope( aperiodic_spectrum: np.ndarray, freq: np.ndarray, - fit_func: str | Callable, + fit_func: AbstractFitFun, fit_bounds: tuple | None = None, scale_factor: float | int = 1, - curv_kwargs: dict = {}, semi_log: bool = True, ) -> tuple[pd.DataFrame, pd.DataFrame]: """get the slope of the aperiodic spectrum""" + if curv_kwargs is None: + curv_kwargs = {} + if isinstance(fit_func, str): off_guess = [aperiodic_spectrum[0]] if fit_bounds is None else fit_bounds[0] exp_guess = ( @@ -81,6 +83,7 @@ def _compute_slope( assert fit_func in valid_slope_functions, f'The slope fitting function has to be in {valid_slope_functions}' if fit_func == 'fixed': + fit_func_object = FixedFitFun(data) fit_f = fixed_model curv_kwargs['p0'] = np.array(off_guess + exp_guess) @@ -129,7 +132,7 @@ def _compute_slope( else: if semi_log: - p, _ = curve_fit(fit_func, freq, np.log10(aperiodic_spectrum), **curv_kwargs) + p, _ = curve_fit(fit_func_object, freq, np.log10(aperiodic_spectrum), **fit_func_object.curve_kwargs) else: p, _ = curve_fit(fit_func, freq, aperiodic_spectrum, **curv_kwargs) diff --git a/pyrasa/utils/fitfuncs.py b/pyrasa/utils/fitfuncs.py new file mode 100644 index 0000000..d895088 --- /dev/null +++ b/pyrasa/utils/fitfuncs.py @@ -0,0 +1,29 @@ +import abc +import numpy as np + + +class AbstractFitFun(abc.ABC): + def __init__(self, *args: Any, **kwargs: Any): + pass + + @abc.abstractmethod + def __call__(self, x: np.ndarray, *args: float, **kwargs: float) -> np.ndarray: + pass + + @property + def curve_kwargs(self) -> dict[str, Any]: + return {} + + +class FixedFitFun(AbstractFitFun): + def __init__(self, x: np.ndarray): + self.x = x + + def __call__(self, x: np.ndarray, b0: float, b: float, *args: float, **kwargs: float) -> np.ndarray: + y_hat = b0 - np.log10(x ** b) + + return y_hat + + @property + def curve_kwargs(self) -> dict[str, Any]: + return {"b0": 0.0, "b": 0.0} diff --git a/pyrasa/utils/types.py b/pyrasa/utils/types.py index c8795a1..edfcb22 100644 --- a/pyrasa/utils/types.py +++ b/pyrasa/utils/types.py @@ -11,6 +11,10 @@ def __call__( ) -> np.ndarray: ... +class FitFun(Protocol): + def __call__(self, x: np.ndarray, *args: float, **kwargs: float) -> np.ndarray: ... + + class IrasaSprintKwargsTyped(TypedDict): mfft: int hop: int