Skip to content

Commit

Permalink
some type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Schmidt Fabian committed Jul 29, 2024
1 parent 438ef79 commit 3b464ed
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 34 deletions.
21 changes: 17 additions & 4 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import fractions
from collections.abc import Callable
from typing import TypedDict

import numpy as np
import scipy.signal as dsp
Expand Down Expand Up @@ -195,7 +196,7 @@ def irasa_sprint( # noqa PLR0915 C901
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
smooth: bool = True,
n_avgs: int = 1,
n_avgs: list = [1],
win_duration: float = 0.4,
hop: int = 10,
win_func: Callable = dsp.windows.hann,
Expand Down Expand Up @@ -302,7 +303,19 @@ def irasa_sprint( # noqa PLR0915 C901
'eigenvalue_weighting': dpss_eigenvalue_weighting,
}

irasa_kwargs = {
class IrasaKwargsTyped(TypedDict):
mfft: int
hop: int
win_duration: float
h: int | None
up_down: str | None
dpss_settings: dict
win_kwargs: dict
time_orig: None | np.ndarray
smooth: bool
n_avgs: list

irasa_kwargs: IrasaKwargsTyped = {
'mfft': mfft,
'hop': hop,
'win_duration': win_duration,
Expand All @@ -327,7 +340,7 @@ def _compute_sgramm( # noqa C901
h: int | None = None,
time_orig: np.ndarray | None = None,
smooth: bool = True,
n_avgs: int = 3,
n_avgs: list = [3],
spectrum_only: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | np.ndarray:
"""Function to compute spectrograms"""
Expand Down Expand Up @@ -404,7 +417,7 @@ def sgramm_smoother(sgramm: np.ndarray, n_avgs: int) -> np.ndarray:
fs=fs,
irasa_fun=_compute_sgramm,
hset=hset,
irasa_kwargs=irasa_kwargs,
irasa_kwargs=dict(irasa_kwargs),
time=time,
)

Expand Down
14 changes: 9 additions & 5 deletions pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy.optimize import curve_fit


def fixed_model(x, b0, b):
def fixed_model(x: np.ndarray, b0: float, b: float) -> np.ndarray:
"""
Specparams fixed fitting function.
Use this to model aperiodic activity without a spectral knee
Expand All @@ -19,7 +19,7 @@ def fixed_model(x, b0, b):
return y_hat


def knee_model(x, b0, k, b1, b2):
def knee_model(x: np.ndarray, b0: float, k: float, b1: float, b2: float) -> np.ndarray:
"""
Model aperiodic activity with a spectral knee and a pre-knee slope.
Use this to model aperiodic activity with a spectral knee
Expand Down Expand Up @@ -56,7 +56,11 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, fit_func: str) -> pd.DataFra


def _compute_slope(
aperiodic_spectrum: np.ndarray, freq: np.ndarray, fit_func: str, fit_bounds: tuple | None = None, scale_factor=1
aperiodic_spectrum: np.ndarray,
freq: np.ndarray,
fit_func: str,
fit_bounds: tuple | None = None,
scale_factor: float | int = 1,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""get the slope of the aperiodic spectrum"""

Expand Down Expand Up @@ -193,7 +197,7 @@ def compute_slope(

if scale:

def num_zeros(decimal):
def num_zeros(decimal: int) -> float:
return np.inf if decimal == 0 else -np.floor(np.log10(abs(decimal))) - 1

scale_factor = 10 ** num_zeros(aperiodic_spectrum.min())
Expand Down Expand Up @@ -231,7 +235,7 @@ def compute_slope_sprint(
fit_func: str,
ch_names: Iterable = (),
fit_bounds: tuple[float, float] | None = None,
):
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
This function can be used to extract aperiodic parameters from the aperiodic spectrogram extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand Down
31 changes: 12 additions & 19 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Utilities for signal decompositon using IRASA"""

from collections.abc import Callable
from copy import copy

import numpy as np
import scipy.signal as dsp


def _crop_data(band, freqs, psd_aperiodic, psd_periodic, axis):
def _crop_data(
band: list | tuple, freqs: np.ndarray, psd_aperiodic: np.ndarray, psd_periodic: np.ndarray, axis: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Utility function to crop spectra to a defined frequency range"""

mask_freqs = np.ma.masked_outside(freqs, *band).mask
Expand All @@ -17,7 +20,7 @@ def _crop_data(band, freqs, psd_aperiodic, psd_periodic, axis):
return freqs, psd_aperiodic, psd_periodic


def _gen_time_from_sft(SFT, sgramm): # noqa N803
def _gen_time_from_sft(SFT: type[dsp.ShortTimeFFT], sgramm: np.ndarray) -> np.ndarray: # noqa N803
"""Generates time from SFT object"""

tmin, tmax = SFT.extent(sgramm.shape[-1])[:2]
Expand All @@ -27,7 +30,7 @@ def _gen_time_from_sft(SFT, sgramm): # noqa N803
return time


def _find_nearest(sgramm_ud, time_array, time_value):
def _find_nearest(sgramm_ud: np.ndarray, time_array: np.ndarray, time_value: float) -> np.ndarray:
"""Find the nearest time point in an up/downsampled spectrogram"""

idx = (np.abs(time_array - time_value)).argmin()
Expand All @@ -40,7 +43,9 @@ def _find_nearest(sgramm_ud, time_array, time_value):
return sgramm_sel


def _get_windows(nperseg, dpss_settings, win_func, win_func_kwargs):
def _get_windows(
nperseg: int, dpss_settings: dict, win_func: Callable, win_func_kwargs: dict
) -> tuple[np.ndarray, np.ndarray]:
"""Generate a window function used for tapering"""
low_bias_ratio = 0.9
max_time_bandwidth = 2.0
Expand Down Expand Up @@ -72,7 +77,7 @@ def _get_windows(nperseg, dpss_settings, win_func, win_func_kwargs):
return win, ratios


def _check_irasa_settings(irasa_params, hset_info):
def _check_irasa_settings(irasa_params: dict, hset_info: tuple) -> None:
"""Check if the input parameters for irasa are specified correctly"""

valid_hset_shape = 3
Expand All @@ -87,14 +92,14 @@ def _check_irasa_settings(irasa_params, hset_info):
# check that evaluated range fits with the data settings
nyquist = irasa_params['fs'] / 2
hmax = np.max(hset_info)
band_evaluated = (irasa_params['band'][0] / hmax, irasa_params['band'][1] * hmax)
band_evaluated: tuple[float, float] = (irasa_params['band'][0] / hmax, irasa_params['band'][1] * hmax)
assert band_evaluated[0] > 0, 'The evaluated frequency range is 0 or lower this makes no sense'
assert band_evaluated[1] < nyquist, (
f'The evaluated frequency range goes up to {np.round(band_evaluated[1], 2)}Hz '
'which is higher than Nyquist (fs / 2)'
)

filter_settings = list(irasa_params['filter_settings'])
filter_settings: list[float] = list(irasa_params['filter_settings'])
if filter_settings[0] is None:
filter_settings[0] = band_evaluated[0]
if filter_settings[1] is None:
Expand All @@ -114,15 +119,3 @@ def _check_irasa_settings(irasa_params, hset_info):
f'> {np.round(filter_settings[0], irasa_params['hset_accuracy'])} '
f'and that band[1] * hset.max() < {np.round(filter_settings[1], irasa_params['hset_accuracy'])}'
)


def _check_psd_settings_raw(data_array, fs, duration, overlap):
"""LEGACY: Check if the kwargs for welch are specified correctly"""

# check parameters for welch
overlap /= 100
assert isinstance(duration, int | float), 'You need to set the duration of your time window in seconds'
assert data_array.shape[1] > int(fs * duration), 'The duration for each segment cant be longer than the actual data'
assert np.logical_and(
overlap < 1, overlap > 0
), 'The overlap between segments cant be larger than 100% or less than 0%'
12 changes: 6 additions & 6 deletions pyrasa/utils/peak_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_peak_params(
peak_threshold: float = 1.0,
min_peak_height: float = 0.01,
peak_width_limits: tuple[float, float] = (0.5, 6.0),
):
) -> pd.DataFrame:
"""
This function can be used to extract peak parameters from the periodic spectrum extracted from IRASA.
The algorithm works by smoothing the spectrum, zeroing out negative values and
Expand Down Expand Up @@ -123,10 +123,10 @@ def get_peak_params_sprint(
smoothing_window: int = 1,
polyorder: int = 1,
cut_spectrum: tuple[float, float] | None = None,
peak_threshold=1,
min_peak_height=0.01,
peak_width_limits=(0.5, 6),
):
peak_threshold: int = 1,
min_peak_height: float = 0.01,
peak_width_limits: tuple[float, float] = (0.5, 6),
) -> pd.DataFrame:
"""
This function can be used to extract peak parameters from the periodic spectrum extracted from IRASA.
The algorithm works by smoothing the spectrum, zeroing out negative values and
Expand Down Expand Up @@ -185,7 +185,7 @@ def get_peak_params_sprint(


# %% find peaks irasa style
def get_band_info(df_peaks, freq_range, ch_names):
def get_band_info(df_peaks: pd.DataFrame, freq_range: tuple[int, int], ch_names: list) -> pd.DataFrame:
"""
This function can be used to extract peaks in a specified frequency range
from the Peak DataFrame obtained via "get_peak_params".
Expand Down

0 comments on commit 3b464ed

Please sign in to comment.