From c9b9856d6fc3dc2a8f36e042e1bb5877ba0ca388 Mon Sep 17 00:00:00 2001 From: Fabi Date: Sat, 10 Aug 2024 02:32:42 +0200 Subject: [PATCH] more docstrings --- pyrasa/irasa.py | 2 +- pyrasa/irasa_mne/__init__.py | 2 + pyrasa/irasa_mne/irasa_mne.py | 2 + pyrasa/irasa_mne/mne_objs.py | 2 + pyrasa/utils/fit_funcs.py | 106 ++++++++++++++++++++++++++++++++-- pyrasa/utils/irasa_utils.py | 1 + pyrasa/utils/types.py | 39 +++++++++++++ 7 files changed, 149 insertions(+), 5 deletions(-) diff --git a/pyrasa/irasa.py b/pyrasa/irasa.py index 5712fea..ec6598e 100644 --- a/pyrasa/irasa.py +++ b/pyrasa/irasa.py @@ -1,4 +1,4 @@ -"""Functions to compute the IRASA algorithm.""" +"""Functions to compute IRASA.""" from collections.abc import Callable from typing import TYPE_CHECKING, Any diff --git a/pyrasa/irasa_mne/__init__.py b/pyrasa/irasa_mne/__init__.py index 438b5d2..a0eb6c2 100644 --- a/pyrasa/irasa_mne/__init__.py +++ b/pyrasa/irasa_mne/__init__.py @@ -1,3 +1,5 @@ +"""Interface to use the IRASA algorithm with MNE objects.""" + from .irasa_mne import irasa_epochs, irasa_raw __all__ = ['irasa_epochs', 'irasa_raw'] diff --git a/pyrasa/irasa_mne/irasa_mne.py b/pyrasa/irasa_mne/irasa_mne.py index 1d714ab..bc65af0 100644 --- a/pyrasa/irasa_mne/irasa_mne.py +++ b/pyrasa/irasa_mne/irasa_mne.py @@ -1,3 +1,5 @@ +"""Interface to use the IRASA algorithm with MNE objects.""" + import mne import numpy as np diff --git a/pyrasa/irasa_mne/mne_objs.py b/pyrasa/irasa_mne/mne_objs.py index b3e1c59..33cd3ca 100644 --- a/pyrasa/irasa_mne/mne_objs.py +++ b/pyrasa/irasa_mne/mne_objs.py @@ -1,3 +1,5 @@ +"""Classes for the MNE Python interface.""" + # %% inherit from spectrum array import matplotlib diff --git a/pyrasa/utils/fit_funcs.py b/pyrasa/utils/fit_funcs.py index e04e300..41d0b81 100644 --- a/pyrasa/utils/fit_funcs.py +++ b/pyrasa/utils/fit_funcs.py @@ -1,3 +1,5 @@ +"""Classes used to model aperiodic spectra""" + import abc import inspect from collections.abc import Callable @@ -10,16 +12,61 @@ def _get_args(f: Callable) -> list: + """ + Extracts the argument names from a function, excluding the first two. + + Parameters + ---------- + f : Callable + The function or method from which to extract argument names. + + Returns + ------- + list + A list of argument names, excluding the first two. + """ + return inspect.getfullargspec(f)[0][2:] def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd.DataFrame: """ - get goodness of fit (i.e. mean squared error and R2) - BIC and AIC currently assume OLS - https://machinelearningmastery.com/probabilistic-model-selection-measures/ + Calculate the goodness of fit metrics for a given model prediction against + actual aperiodic power spectral density (PSD) data. + + This function computes several statistics to evaluate how well the predicted PSD values + match the observed PSD values. The metrics include Mean Squared Error (MSE), R-squared (R²), + Bayesian Information Criterion (BIC), and Akaike Information Criterion (AIC). + + Parameters + ---------- + psd : np.ndarray + The observed power spectral density values. + psd_pred : np.ndarray + The predicted power spectral density values from the model. + k : int + The number of parameters in the curve fitting function used to predict the `psd`. + fit_type : str + A description or label for the type of fit/model used, which will be included in the output DataFrame. + + Returns + ------- + pd.DataFrame + A DataFrame containing the goodness of fit metrics: + - 'mse': Mean Squared Error + - 'r_squared': R-squared value + - 'BIC': Bayesian Information Criterion + - 'AIC': Akaike Information Criterion + - 'fit_type': The type of fit/model used (provided as input) + + Notes + ----- + - BIC and AIC calculations currently assume Ordinary Least Squares (OLS) regression. + + References + ---------- + For further details on BIC and AIC, see: https://machinelearningmastery.com/probabilistic-model-selection-measures/ """ - # k number of parameters in curve fitting function # add np.log10 to psd residuals = psd - psd_pred @@ -39,6 +86,57 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd @define class AbstractFitFun(abc.ABC): + """ + Abstract base class for fitting functions used to model aperiodic spectra. + + This class provides a framework for defining and fitting models to aperiodic spectra. + It handles common functionality required for fitting a model, such as scaling and goodness-of-fit + computation. Subclasses should implement the `func` method to define the specific fitting function + used for curve fitting. + + Attributes + ---------- + freq : np.ndarray + The frequency values associated with the aperiodic spectrum data. + aperiodic_spectrum : np.ndarray + The aperiodic spectrum data to which the model will be fit. + scale_factor : int | float + A scaling factor used to adjust the fit results. + label : ClassVar[str] + A label to identify the type of fit or model used. Default is 'custom'. + log10_aperiodic : ClassVar[bool] + If True, the aperiodic spectrum values will be transformed using log10. Default is False. + log10_freq : ClassVar[bool] + If True, the frequency values will be transformed using log10. Default is False. + + Methods + ------- + __attrs_post_init__() + Post-initialization method to apply log10 transformations if specified. + func(x: np.ndarray, *args: float) -> np.ndarray + Abstract method to define the model function. Must be implemented by subclasses + and should be applicable to scipy.optimize.curve_fit. + curve_kwargs() -> dict[str, Any] + Returns keyword arguments for the curve fitting process. + add_infos_to_df(df_params: pd.DataFrame) -> pd.DataFrame + Method to add additional information to the parameters DataFrame. Can be overridden by subclasses. + handle_scaling(df_params: pd.DataFrame, scale_factor: float) -> pd.DataFrame + Adjusts the parameters DataFrame based on the scaling factor. Can be overridden by subclasses. + fit_func() -> tuple[pd.DataFrame, pd.DataFrame] + Fits the model to the data and returns DataFrames containing the model parameters and goodness-of-fit metrics. + + Notes + ----- + - Subclasses must implement the `func` method to define the model's functional form. + - The `curve_kwargs` method can be overridden to customize curve fitting options. + - The `add_infos_to_df` and `handle_scaling` methods are intended to be overridden if additional + functionality or specific scaling behavior is required. + + References + ---------- + For details on goodness-of-fit metrics and their calculations, see the documentation for `_get_gof`. + """ + freq: np.ndarray aperiodic_spectrum: np.ndarray scale_factor: int | float diff --git a/pyrasa/utils/irasa_utils.py b/pyrasa/utils/irasa_utils.py index 47eda76..9882430 100644 --- a/pyrasa/utils/irasa_utils.py +++ b/pyrasa/utils/irasa_utils.py @@ -72,6 +72,7 @@ def _get_windows( nperseg: int, dpss_settings: dict, win_func: Callable, win_func_kwargs: dict ) -> tuple[np.ndarray, np.ndarray]: """Generate a window function used for tapering""" + low_bias_ratio = 0.9 min_time_bandwidth = 2.0 win_func_kwargs = copy(win_func_kwargs) diff --git a/pyrasa/utils/types.py b/pyrasa/utils/types.py index 31ab85d..ef49197 100644 --- a/pyrasa/utils/types.py +++ b/pyrasa/utils/types.py @@ -1,3 +1,5 @@ +"""Custom classes for pyrasa.""" + from typing import Protocol, TypedDict import numpy as np @@ -6,12 +8,49 @@ class IrasaFun(Protocol): + """ + A protocol defining the interface for an IRASA function used in the PyRASA library. + + The `IrasaFun` protocol specifies the expected signature of a function used to separate + aperiodic and periodic components of a power spectrum using the IRASA algorithm. + Any function conforming to this protocol can be passed to other PyRASA functions + as a custom IRASA implementation. + + Methods + ------- + __call__(data: np.ndarray, fs: int, h: float, + up_down: str | None, time_orig: np.ndarray | None = None) -> np.ndarray + Separates the input data into its aperiodic and periodic components based on the IRASA method. + + Parameters + ---------- + data : np.ndarray + The input time series data to be analyzed. + fs : int + The sampling frequency of the input data. + h : float + The resampling factor used in the IRASA algorithm. + up_down : str | None + A string indicating the direction of resampling ('up' or 'down'). + If None, no resampling is performed. + time_orig : np.ndarray | None, optional + The original time points of the data, used for interpolation if necessary. + If None, no interpolation is performed. + + Returns + ------- + np.ndarray + The output of the IRASA function. + """ + def __call__( self, data: np.ndarray, fs: int, h: float, up_down: str | None, time_orig: np.ndarray | None = None ) -> np.ndarray: ... class IrasaSprintKwargsTyped(TypedDict): + """TypedDict for the IRASA sprint function.""" + nfft: int hop: int win_duration: float