Skip to content

Commit

Permalink
more docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 10, 2024
1 parent e462006 commit c9b9856
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Functions to compute the IRASA algorithm."""
"""Functions to compute IRASA."""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any
Expand Down
2 changes: 2 additions & 0 deletions pyrasa/irasa_mne/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Interface to use the IRASA algorithm with MNE objects."""

from .irasa_mne import irasa_epochs, irasa_raw

__all__ = ['irasa_epochs', 'irasa_raw']
2 changes: 2 additions & 0 deletions pyrasa/irasa_mne/irasa_mne.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Interface to use the IRASA algorithm with MNE objects."""

import mne
import numpy as np

Expand Down
2 changes: 2 additions & 0 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Classes for the MNE Python interface."""

# %% inherit from spectrum array

import matplotlib
Expand Down
106 changes: 102 additions & 4 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Classes used to model aperiodic spectra"""

import abc
import inspect
from collections.abc import Callable
Expand All @@ -10,16 +12,61 @@


def _get_args(f: Callable) -> list:
"""
Extracts the argument names from a function, excluding the first two.
Parameters
----------
f : Callable
The function or method from which to extract argument names.
Returns
-------
list
A list of argument names, excluding the first two.
"""

return inspect.getfullargspec(f)[0][2:]


def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd.DataFrame:
"""
get goodness of fit (i.e. mean squared error and R2)
BIC and AIC currently assume OLS
https://machinelearningmastery.com/probabilistic-model-selection-measures/
Calculate the goodness of fit metrics for a given model prediction against
actual aperiodic power spectral density (PSD) data.
This function computes several statistics to evaluate how well the predicted PSD values
match the observed PSD values. The metrics include Mean Squared Error (MSE), R-squared (R²),
Bayesian Information Criterion (BIC), and Akaike Information Criterion (AIC).
Parameters
----------
psd : np.ndarray
The observed power spectral density values.
psd_pred : np.ndarray
The predicted power spectral density values from the model.
k : int
The number of parameters in the curve fitting function used to predict the `psd`.
fit_type : str
A description or label for the type of fit/model used, which will be included in the output DataFrame.
Returns
-------
pd.DataFrame
A DataFrame containing the goodness of fit metrics:
- 'mse': Mean Squared Error
- 'r_squared': R-squared value
- 'BIC': Bayesian Information Criterion
- 'AIC': Akaike Information Criterion
- 'fit_type': The type of fit/model used (provided as input)
Notes
-----
- BIC and AIC calculations currently assume Ordinary Least Squares (OLS) regression.
References
----------
For further details on BIC and AIC, see: https://machinelearningmastery.com/probabilistic-model-selection-measures/
"""
# k number of parameters in curve fitting function

# add np.log10 to psd
residuals = psd - psd_pred
Expand All @@ -39,6 +86,57 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd

@define
class AbstractFitFun(abc.ABC):
"""
Abstract base class for fitting functions used to model aperiodic spectra.
This class provides a framework for defining and fitting models to aperiodic spectra.
It handles common functionality required for fitting a model, such as scaling and goodness-of-fit
computation. Subclasses should implement the `func` method to define the specific fitting function
used for curve fitting.
Attributes
----------
freq : np.ndarray
The frequency values associated with the aperiodic spectrum data.
aperiodic_spectrum : np.ndarray
The aperiodic spectrum data to which the model will be fit.
scale_factor : int | float
A scaling factor used to adjust the fit results.
label : ClassVar[str]
A label to identify the type of fit or model used. Default is 'custom'.
log10_aperiodic : ClassVar[bool]
If True, the aperiodic spectrum values will be transformed using log10. Default is False.
log10_freq : ClassVar[bool]
If True, the frequency values will be transformed using log10. Default is False.
Methods
-------
__attrs_post_init__()
Post-initialization method to apply log10 transformations if specified.
func(x: np.ndarray, *args: float) -> np.ndarray
Abstract method to define the model function. Must be implemented by subclasses
and should be applicable to scipy.optimize.curve_fit.
curve_kwargs() -> dict[str, Any]
Returns keyword arguments for the curve fitting process.
add_infos_to_df(df_params: pd.DataFrame) -> pd.DataFrame
Method to add additional information to the parameters DataFrame. Can be overridden by subclasses.
handle_scaling(df_params: pd.DataFrame, scale_factor: float) -> pd.DataFrame
Adjusts the parameters DataFrame based on the scaling factor. Can be overridden by subclasses.
fit_func() -> tuple[pd.DataFrame, pd.DataFrame]
Fits the model to the data and returns DataFrames containing the model parameters and goodness-of-fit metrics.
Notes
-----
- Subclasses must implement the `func` method to define the model's functional form.
- The `curve_kwargs` method can be overridden to customize curve fitting options.
- The `add_infos_to_df` and `handle_scaling` methods are intended to be overridden if additional
functionality or specific scaling behavior is required.
References
----------
For details on goodness-of-fit metrics and their calculations, see the documentation for `_get_gof`.
"""

freq: np.ndarray
aperiodic_spectrum: np.ndarray
scale_factor: int | float
Expand Down
1 change: 1 addition & 0 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _get_windows(
nperseg: int, dpss_settings: dict, win_func: Callable, win_func_kwargs: dict
) -> tuple[np.ndarray, np.ndarray]:
"""Generate a window function used for tapering"""

low_bias_ratio = 0.9
min_time_bandwidth = 2.0
win_func_kwargs = copy(win_func_kwargs)
Expand Down
39 changes: 39 additions & 0 deletions pyrasa/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Custom classes for pyrasa."""

from typing import Protocol, TypedDict

import numpy as np
Expand All @@ -6,12 +8,49 @@


class IrasaFun(Protocol):
"""
A protocol defining the interface for an IRASA function used in the PyRASA library.
The `IrasaFun` protocol specifies the expected signature of a function used to separate
aperiodic and periodic components of a power spectrum using the IRASA algorithm.
Any function conforming to this protocol can be passed to other PyRASA functions
as a custom IRASA implementation.
Methods
-------
__call__(data: np.ndarray, fs: int, h: float,
up_down: str | None, time_orig: np.ndarray | None = None) -> np.ndarray
Separates the input data into its aperiodic and periodic components based on the IRASA method.
Parameters
----------
data : np.ndarray
The input time series data to be analyzed.
fs : int
The sampling frequency of the input data.
h : float
The resampling factor used in the IRASA algorithm.
up_down : str | None
A string indicating the direction of resampling ('up' or 'down').
If None, no resampling is performed.
time_orig : np.ndarray | None, optional
The original time points of the data, used for interpolation if necessary.
If None, no interpolation is performed.
Returns
-------
np.ndarray
The output of the IRASA function.
"""

def __call__(
self, data: np.ndarray, fs: int, h: float, up_down: str | None, time_orig: np.ndarray | None = None
) -> np.ndarray: ...


class IrasaSprintKwargsTyped(TypedDict):
"""TypedDict for the IRASA sprint function."""

nfft: int
hop: int
win_duration: float
Expand Down

0 comments on commit c9b9856

Please sign in to comment.