diff --git a/pixi.lock b/pixi.lock index f9729fc..46a3615 100644 --- a/pixi.lock +++ b/pixi.lock @@ -28539,8 +28539,9 @@ packages: name: pyrasa version: 0.1.0.dev0 path: . - sha256: be3bb8dc71703f61f5ebe96399791305ea95ca48b8f56dfdb4eb24655e185172 + sha256: 39640bc3e1a4ead8752644076a356e607560548282bc8947616e21831a4d7587 requires_dist: + - attrs - numpy - pandas - scipy>=1.12 diff --git a/pyproject.toml b/pyproject.toml index 21de3c7..4bd7e79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] keywords=['spectral parametrization', 'oscillations', 'power spectra', '1/f'] requires-python = ">= 3.11" -dependencies = ["numpy", "pandas", "scipy>=1.12"] +dependencies = ["numpy", "pandas", "scipy>=1.12", "attrs"] [project.optional-dependencies] mne = ['mne'] diff --git a/pyrasa/irasa.py b/pyrasa/irasa.py index 49a970b..279a5e2 100644 --- a/pyrasa/irasa.py +++ b/pyrasa/irasa.py @@ -1,65 +1,23 @@ -import fractions from collections.abc import Callable from typing import TYPE_CHECKING, Any import numpy as np import scipy.signal as dsp +from pyrasa.utils.irasa_spectrum import IrasaSpectrum +from pyrasa.utils.irasa_tf_spectrum import IrasaTfSpectrum + # from scipy.stats.mstats import gmean from pyrasa.utils.irasa_utils import ( _check_irasa_settings, _compute_psd_welch, _compute_sgramm, - _crop_data, # _find_nearest, _gen_time_from_sft, _get_windows, + _crop_data, + _gen_irasa, ) -from pyrasa.utils.types import IrasaFun if TYPE_CHECKING: - from pyrasa.utils.input_classes import IrasaSprintKwargsTyped - - -# TODO: Port to Cython -def _gen_irasa( - data: np.ndarray, - orig_spectrum: np.ndarray, - fs: int, - irasa_fun: IrasaFun, - hset: np.ndarray, - time: np.ndarray | None = None, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - This function is implementing the IRASA algorithm using a custom function to - compute a power/cross-spectral density and returns an "original", "periodic" and "aperiodic spectrum". - This implementation of the IRASA algorithm is based on the yasa.irasa function in (Vallat & Walker, 2021). - - [1] Vallat, Raphael, and Matthew P. Walker. “An open-source, - high-performance tool for automated sleep staging.” - Elife 10 (2021). doi: https://doi.org/10.7554/eLife.70092 - """ - - spectra = np.zeros((len(hset), *orig_spectrum.shape)) - for i, h in enumerate(hset): - rat = fractions.Fraction(str(h)) - up, down = rat.numerator, rat.denominator - - # Much faster than FFT-based resampling - data_up = dsp.resample_poly(data, up, down, axis=-1) - data_down = dsp.resample_poly(data, down, up, axis=-1) - - # Calculate an up/downsampled version of the PSD using same params as original - spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up') - spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down') - - # geometric mean between up and downsampled - # be aware of the input dimensions - if spectra.ndim == 2: # noqa PLR2004 - spectra[i, :] = np.sqrt(spectrum_up * spectrum_dw) - if spectra.ndim == 3: # noqa PLR2004 - spectra[i, :, :] = np.sqrt(spectrum_up * spectrum_dw) - - aperiodic_spectrum = np.median(spectra, axis=0) - periodic_spectrum = orig_spectrum - aperiodic_spectrum - return orig_spectrum, aperiodic_spectrum, periodic_spectrum + from pyrasa.utils.types import IrasaSprintKwargsTyped # %% irasa @@ -68,6 +26,7 @@ def irasa( fs: int, band: tuple[float, float], psd_kwargs: dict, + ch_names: np.ndarray | None = None, win_func: Callable = dsp.windows.hann, win_func_kwargs: dict | None = None, dpss_settings_time_bandwidth: float = 2.0, @@ -76,7 +35,7 @@ def irasa( filter_settings: tuple[float | None, float | None] = (None, None), hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05), hset_accuracy: int = 4, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> IrasaSpectrum: """ This function can be used to generate aperiodic and periodic power spectra from a time series using the IRASA algorithm (Wen & Liu, 2016). @@ -181,15 +140,18 @@ def _local_irasa_fun( data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset ) - freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1) + freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1) - return freq, psd_aperiodic, psd_periodic + return IrasaSpectrum( + freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic, ch_names=ch_names + ) # irasa sprint def irasa_sprint( # noqa PLR0915 C901 data: np.ndarray, fs: int, + ch_names: np.ndarray | None = None, band: tuple[float, float] = (1.0, 100.0), freq_res: float = 0.5, win_duration: float = 0.4, @@ -202,7 +164,7 @@ def irasa_sprint( # noqa PLR0915 C901 filter_settings: tuple[float | None, float | None] = (None, None), hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05), hset_accuracy: int = 4, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> IrasaTfSpectrum: """ This function can be used to seperate aperiodic from periodic power spectra @@ -328,10 +290,19 @@ def _local_irasa_fun( ) # NOTE: we need to transpose the data as crop_data extracts stuff from the last axis - freq, sgramm_aperiodic, sgramm_periodic = _crop_data(band, freq, sgramm_aperiodic, sgramm_periodic, axis=0) + freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data( + band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0 + ) # adjust time info (i.e. cut the padded stuff) tmax = data.shape[1] / fs t_mask = np.logical_and(time >= 0, time < tmax) - return sgramm_aperiodic[:, t_mask], sgramm_periodic[:, t_mask], freq, time[t_mask] + return IrasaTfSpectrum( + freqs=freq, + time=time[t_mask], + raw_spectrum=sgramm, + periodic=sgramm_periodic[:, t_mask], + aperiodic=sgramm_aperiodic[:, t_mask], + ch_names=ch_names, + ) diff --git a/pyrasa/irasa_mne/irasa_mne.py b/pyrasa/irasa_mne/irasa_mne.py index 1f92fcb..41c5eae 100644 --- a/pyrasa/irasa_mne/irasa_mne.py +++ b/pyrasa/irasa_mne/irasa_mne.py @@ -5,6 +5,8 @@ from pyrasa.irasa_mne.mne_objs import ( AperiodicEpochsSpectrum, AperiodicSpectrumArray, + IrasaEpoched, + IrasaRaw, PeriodicEpochsSpectrum, PeriodicSpectrumArray, ) @@ -16,8 +18,7 @@ def irasa_raw( duration: float | None = None, overlap: float | int = 50, hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05), - as_array: bool = False, -) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> IrasaRaw: """ This function can be used to seperate aperiodic from periodic power spectra using the IRASA algorithm (Wen & Liu, 2016). @@ -87,7 +88,7 @@ def irasa_raw( 'noverlap': int(fs * duration * overlap), } - freq, psd_aperiodic, psd_periodic = irasa( + irasa_spectrum = irasa( data_array, fs=fs, band=band, @@ -96,22 +97,17 @@ def irasa_raw( psd_kwargs=kwargs_psd, ) - if as_array is True: - return psd_aperiodic, psd_periodic, freq - - else: - aperiodic = AperiodicSpectrumArray(psd_aperiodic, info, freqs=freq) - periodic = PeriodicSpectrumArray(psd_periodic, info, freqs=freq) - - return aperiodic, periodic + return IrasaRaw( + periodic=PeriodicSpectrumArray(irasa_spectrum.periodic, info, freqs=irasa_spectrum.freqs), + aperiodic=AperiodicSpectrumArray(irasa_spectrum.aperiodic, info, freqs=irasa_spectrum.freqs), + ) def irasa_epochs( data: mne.Epochs, band: tuple[float, float] = (1.0, 100.0), hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05), - as_array: bool = False, -) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> IrasaEpoched: """ This function can be used to seperate aperiodic from periodic power spectra using the IRASA algorithm (Wen & Liu, 2016). @@ -177,7 +173,7 @@ def irasa_epochs( # Do the actual IRASA stuff.. psd_list_aperiodic, psd_list_periodic = [], [] for epoch in data_array: - freq, psd_aperiodic, psd_periodic = irasa( + irasa_spectrum = irasa( epoch, fs=fs, band=band, @@ -185,16 +181,17 @@ def irasa_epochs( hset_info=hset_info, psd_kwargs=kwargs_psd, ) - psd_list_aperiodic.append(psd_aperiodic) - psd_list_periodic.append(psd_periodic) - - psd_aperiodic = np.array(psd_list_aperiodic) - psd_periodic = np.array(psd_list_periodic) - - if as_array is True: - return psd_aperiodic, psd_periodic, freq - else: - aperiodic = AperiodicEpochsSpectrum(psd_aperiodic, info, freqs=freq, events=events, event_id=event_ids) - periodic = PeriodicEpochsSpectrum(psd_periodic, info, freqs=freq, events=events, event_id=event_ids) - - return aperiodic, periodic + psd_list_aperiodic.append(irasa_spectrum.aperiodic.copy()) + psd_list_periodic.append(irasa_spectrum.periodic.copy()) + + psds_aperiodic = np.array(psd_list_aperiodic) + psds_periodic = np.array(psd_list_periodic) + + return IrasaEpoched( + periodic=PeriodicEpochsSpectrum( + psds_periodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids + ), + aperiodic=AperiodicEpochsSpectrum( + psds_aperiodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids + ), + ) diff --git a/pyrasa/irasa_mne/mne_objs.py b/pyrasa/irasa_mne/mne_objs.py index c420384..b5ce0ed 100644 --- a/pyrasa/irasa_mne/mne_objs.py +++ b/pyrasa/irasa_mne/mne_objs.py @@ -3,10 +3,12 @@ import mne import numpy as np import pandas as pd +from attrs import define from mne.time_frequency import EpochsSpectrumArray, SpectrumArray from pyrasa.utils.aperiodic_utils import compute_slope from pyrasa.utils.peak_utils import get_peak_params +from pyrasa.utils.types import SlopeFit class PeriodicSpectrumArray(SpectrumArray): @@ -169,7 +171,7 @@ def __init__( def get_slopes( self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None - ) -> tuple[pd.DataFrame, pd.DataFrame]: + ) -> SlopeFit: """ This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA. The algorithm works by applying one of two different curve fit functions and returns the associated parameters, @@ -190,7 +192,7 @@ def get_slopes( """ - df_aps, df_gof = compute_slope( + return compute_slope( self.get_data(), self.freqs, ch_names=self.ch_names, @@ -199,8 +201,6 @@ def get_slopes( fit_bounds=fit_bounds, ) - return df_aps, df_gof - # %% class PeriodicEpochsSpectrum(EpochsSpectrumArray): @@ -394,7 +394,7 @@ def __init__( def get_slopes( self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None - ) -> tuple[pd.DataFrame, pd.DataFrame]: + ) -> SlopeFit: """ This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA. The algorithm works by applying one of two different curve fit functions and returns the associated parameters, @@ -420,7 +420,7 @@ def get_slopes( aps_list, gof_list = [], [] for ix, cur_epoch in enumerate(self.get_data()): - df_aps, df_gof = compute_slope( + slope_fit = compute_slope( cur_epoch, self.freqs, ch_names=self.ch_names, @@ -429,9 +429,21 @@ def get_slopes( fit_bounds=fit_bounds, ) - df_aps['event_id'] = event_dict[events[ix]] - df_gof['event_id'] = event_dict[events[ix]] - aps_list.append(df_aps) - gof_list.append(df_gof) + slope_fit.aperiodic_params['event_id'] = event_dict[events[ix]] + slope_fit.gof['event_id'] = event_dict[events[ix]] + aps_list.append(slope_fit.aperiodic_params.copy()) + gof_list.append(slope_fit.gof.copy()) + + return SlopeFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list)) + + +@define +class IrasaRaw: + periodic: PeriodicSpectrumArray + aperiodic: AperiodicSpectrumArray + - return pd.concat(aps_list), pd.concat(gof_list) +@define +class IrasaEpoched: + periodic: PeriodicEpochsSpectrum + aperiodic: AperiodicEpochsSpectrum diff --git a/pyrasa/utils/aperiodic_utils.py b/pyrasa/utils/aperiodic_utils.py index 09669e3..030adf4 100644 --- a/pyrasa/utils/aperiodic_utils.py +++ b/pyrasa/utils/aperiodic_utils.py @@ -7,6 +7,8 @@ import pandas as pd from scipy.optimize import curve_fit +from pyrasa.utils.types import SlopeFit + def fixed_model(x: np.ndarray, b0: float, b: float) -> np.ndarray: """ @@ -134,10 +136,10 @@ def compute_slope( aperiodic_spectrum: np.ndarray, freqs: np.ndarray, fit_func: str, - ch_names: Iterable = (), + ch_names: Iterable | None = None, scale: bool = False, fit_bounds: tuple[float, float] | None = None, -) -> tuple[pd.DataFrame, pd.DataFrame]: +) -> SlopeFit: """ This function can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA. The algorithm works by applying one of two different curve fit functions and returns the associated parameters, @@ -177,8 +179,8 @@ def compute_slope( assert freqs.ndim == 1, 'freqs needs to be of shape (freqs,).' assert isinstance( - ch_names, list | tuple | np.ndarray - ), 'Channel names should be of type list, tuple or numpy.ndarray' + ch_names, list | tuple | np.ndarray | None + ), 'Channel names should be of type list, tuple or numpy.ndarray or None' if fit_bounds is not None: fmin, fmax = freqs.min(), freqs.max() @@ -194,7 +196,7 @@ def compute_slope( aperiodic_spectrum = aperiodic_spectrum[:, 1:] # generate channel names if not given - if len(ch_names) == 0: + if ch_names is None: ch_names = np.arange(aperiodic_spectrum.shape[0]) if scale: @@ -224,10 +226,7 @@ def num_zeros(decimal: int) -> float: gof_list.append(gof) # combine & return - df_aps = pd.concat(ap_list) - df_gof = pd.concat(gof_list) - - return df_aps, df_gof + return SlopeFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list)) def compute_slope_sprint( @@ -235,9 +234,10 @@ def compute_slope_sprint( freqs: np.ndarray, times: np.ndarray, fit_func: str, - ch_names: Iterable = (), + scale: bool = False, + ch_names: Iterable | None = None, fit_bounds: tuple[float, float] | None = None, -) -> tuple[pd.DataFrame, pd.DataFrame]: +) -> SlopeFit: """ This function can be used to extract aperiodic parameters from the aperiodic spectrogram extracted from IRASA. The algorithm works by applying one of two different curve fit functions and returns the associated parameters, @@ -268,16 +268,18 @@ def compute_slope_sprint( ap_t_list, gof_t_list = [], [] for ix, t in enumerate(times): - cur_aps, cur_gof = compute_slope( - aperiodic_spectrum[:, :, ix], freqs=freqs, fit_func=fit_func, ch_names=ch_names, fit_bounds=fit_bounds + slope_fit = compute_slope( + aperiodic_spectrum[:, :, ix], + freqs=freqs, + fit_func=fit_func, + ch_names=ch_names, + fit_bounds=fit_bounds, + scale=scale, ) - cur_aps['time'] = t - cur_gof['time'] = t - - ap_t_list.append(cur_aps) - gof_t_list.append(cur_gof) + slope_fit.aperiodic_params['time'] = t + slope_fit.gof['time'] = t - df_ap_time = pd.concat(ap_t_list) - df_gof_time = pd.concat(gof_t_list) + ap_t_list.append(slope_fit.aperiodic_params) + gof_t_list.append(slope_fit.gof) - return df_ap_time, df_gof_time + return SlopeFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list)) diff --git a/pyrasa/utils/input_classes.py b/pyrasa/utils/input_classes.py deleted file mode 100644 index 4090853..0000000 --- a/pyrasa/utils/input_classes.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import TypedDict - - -class IrasaSprintKwargsTyped(TypedDict): - mfft: int - hop: int - win_duration: float - dpss_settings: dict - win_kwargs: dict - # smooth: bool - # n_avgs: list diff --git a/pyrasa/utils/irasa_spectrum.py b/pyrasa/utils/irasa_spectrum.py new file mode 100644 index 0000000..9c7014b --- /dev/null +++ b/pyrasa/utils/irasa_spectrum.py @@ -0,0 +1,93 @@ +import numpy as np +import pandas as pd +from attrs import define + +from pyrasa.utils.aperiodic_utils import compute_slope +from pyrasa.utils.peak_utils import get_peak_params +from pyrasa.utils.types import SlopeFit + + +@define +class IrasaSpectrum: + freqs: np.ndarray + raw_spectrum: np.ndarray + aperiodic: np.ndarray + periodic: np.ndarray + ch_names: np.ndarray | None + + def get_slopes( + self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None + ) -> SlopeFit: + """ + This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA. + The algorithm works by applying one of two different curve fit functions and returns the associated parameters, + as well as the respective goodness of fit. + + Parameters: + fit_func : string + Can be either "fixed" or "knee". + fit_bounds : None, tuple + Lower and upper bound for the fit function, + should be None if the whole frequency range is desired. + Otherwise a tuple of (lower, upper) + + Returns: SlopeFit + df_aps: DataFrame + DataFrame containing the center frequency, bandwidth and peak height for each channel + df_gof: DataFrame + DataFrame containing the goodness of fit of the specific fit function for each channel. + + """ + return compute_slope( + aperiodic_spectrum=self.aperiodic, + freqs=self.freqs, + ch_names=self.ch_names, + scale=scale, + fit_func=fit_func, + fit_bounds=fit_bounds, + ) + + def get_peaks( + self, + smoothing_window: float | int = 1, + cut_spectrum: tuple[float, float] | None = None, + peak_threshold: float = 2.5, + min_peak_height: float = 0.0, + polyorder: int = 1, + peak_width_limits: tuple[float, float] = (0.5, 12), + ) -> pd.DataFrame: + """ + This method can be used to extract peak parameters from the periodic spectrum extracted from IRASA. + The algorithm works by smoothing the spectrum, zeroing out negative values and + extracting peaks based on user specified parameters. + + Parameters: smoothing window : int, optional, default: 2 + Smoothing window in Hz handed over to the savitzky-golay filter. + cut_spectrum : tuple of (float, float), optional, default (1, 40) + Cut the periodic spectrum to limit peak finding to a sensible range + peak_threshold : float, optional, default: 1 + Relative threshold for detecting peaks. This threshold is defined in + relative units of the periodic spectrum + min_peak_height : float, optional, default: 0.01 + Absolute threshold for identifying peaks. The threhsold is defined in relative + units of the power spectrum. Setting this is somewhat necessary when a + "knee" is present in the data as it will carry over to the periodic spctrum in irasa. + peak_width_limits : tuple of (float, float), optional, default (.5, 12) + Limits on possible peak width, in Hz, as (lower_bound, upper_bound) + + Returns: df_peaks: DataFrame + DataFrame containing the center frequency, bandwidth and peak height for each channel + + """ + + return get_peak_params( + self.periodic, + self.freqs, + self.ch_names, + smoothing_window=smoothing_window, + cut_spectrum=cut_spectrum, + peak_threshold=peak_threshold, + min_peak_height=min_peak_height, + polyorder=polyorder, + peak_width_limits=peak_width_limits, + ) diff --git a/pyrasa/utils/irasa_tf_spectrum.py b/pyrasa/utils/irasa_tf_spectrum.py new file mode 100644 index 0000000..cb78fea --- /dev/null +++ b/pyrasa/utils/irasa_tf_spectrum.py @@ -0,0 +1,100 @@ +import numpy as np +import pandas as pd +from attrs import define + +from pyrasa.utils.aperiodic_utils import compute_slope_sprint +from pyrasa.utils.peak_utils import get_peak_params_sprint +from pyrasa.utils.types import SlopeFit + +min_ndim = 2 + + +@define +class IrasaTfSpectrum: + freqs: np.ndarray + time: np.ndarray + raw_spectrum: np.ndarray + aperiodic: np.ndarray + periodic: np.ndarray + ch_names: np.ndarray | None + + def get_slopes( + self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None + ) -> SlopeFit: + """ + This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA. + The algorithm works by applying one of two different curve fit functions and returns the associated parameters, + as well as the respective goodness of fit. + + Parameters: + fit_func : string + Can be either "fixed" or "knee". + fit_bounds : None, tuple + Lower and upper bound for the fit function, + should be None if the whole frequency range is desired. + Otherwise a tuple of (lower, upper) + + Returns: SlopeFit + df_aps: DataFrame + DataFrame containing the center frequency, bandwidth and peak height for each channel + df_gof: DataFrame + DataFrame containing the goodness of fit of the specific fit function for each channel. + + """ + return compute_slope_sprint( + aperiodic_spectrum=self.aperiodic[np.newaxis, :, :] if self.aperiodic.ndim == min_ndim else self.aperiodic, + freqs=self.freqs, + times=self.time, + ch_names=self.ch_names, + scale=scale, + fit_func=fit_func, + fit_bounds=fit_bounds, + ) + + def get_peaks( + self, + smooth: bool = True, + smoothing_window: float | int = 1, + cut_spectrum: tuple[float, float] | None = None, + peak_threshold: float = 2.5, + min_peak_height: float = 0.0, + polyorder: int = 1, + peak_width_limits: tuple[float, float] = (0.5, 12), + ) -> pd.DataFrame: + """ + This method can be used to extract peak parameters from the periodic spectrum extracted from IRASA. + The algorithm works by smoothing the spectrum, zeroing out negative values and + extracting peaks based on user specified parameters. + + Parameters: smoothing window : int, optional, default: 2 + Smoothing window in Hz handed over to the savitzky-golay filter. + cut_spectrum : tuple of (float, float), optional, default (1, 40) + Cut the periodic spectrum to limit peak finding to a sensible range + peak_threshold : float, optional, default: 1 + Relative threshold for detecting peaks. This threshold is defined in + relative units of the periodic spectrum + min_peak_height : float, optional, default: 0.01 + Absolute threshold for identifying peaks. The threhsold is defined in relative + units of the power spectrum. Setting this is somewhat necessary when a + "knee" is present in the data as it will carry over to the periodic spctrum in irasa. + peak_width_limits : tuple of (float, float), optional, default (.5, 12) + Limits on possible peak width, in Hz, as (lower_bound, upper_bound) + + Returns: df_peaks: DataFrame + DataFrame containing the center frequency, bandwidth and peak height for each channel + + """ + + return get_peak_params_sprint( + periodic_spectrum=self.periodic[np.newaxis, :, :] if self.periodic.ndim == min_ndim else self.periodic, + freqs=self.freqs, + times=self.time, + ch_names=self.ch_names, + smooth=smooth, + smoothing_window=smoothing_window, + cut_spectrum=cut_spectrum, + peak_threshold=peak_threshold, + min_peak_height=min_peak_height, + polyorder=polyorder, + peak_width_limits=peak_width_limits, + ) diff --git a/pyrasa/utils/irasa_utils.py b/pyrasa/utils/irasa_utils.py index 3657f63..9df90b8 100644 --- a/pyrasa/utils/irasa_utils.py +++ b/pyrasa/utils/irasa_utils.py @@ -1,5 +1,6 @@ """Utilities for signal decompositon using IRASA""" +import fractions from collections.abc import Callable from copy import copy @@ -7,18 +8,70 @@ import scipy.signal as dsp from scipy.signal import ShortTimeFFT +from pyrasa.utils.types import IrasaFun -def _crop_data( - band: list | tuple, freqs: np.ndarray, psd_aperiodic: np.ndarray, psd_periodic: np.ndarray, axis: int + +# TODO: Port to Cython +def _gen_irasa( + data: np.ndarray, + orig_spectrum: np.ndarray, + fs: int, + irasa_fun: IrasaFun, + hset: np.ndarray, + time: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + This function is implementing the IRASA algorithm using a custom function to + compute a power/cross-spectral density and returns an "original", "periodic" and "aperiodic spectrum". + This implementation of the IRASA algorithm is based on the yasa.irasa function in (Vallat & Walker, 2021). + + [1] Vallat, Raphael, and Matthew P. Walker. “An open-source, + high-performance tool for automated sleep staging.” + Elife 10 (2021). doi: https://doi.org/10.7554/eLife.70092 + """ + + spectra = np.zeros((len(hset), *orig_spectrum.shape)) + for i, h in enumerate(hset): + rat = fractions.Fraction(str(h)) + up, down = rat.numerator, rat.denominator + + # Much faster than FFT-based resampling + data_up = dsp.resample_poly(data, up, down, axis=-1) + data_down = dsp.resample_poly(data, down, up, axis=-1) + + # Calculate an up/downsampled version of the PSD using same params as original + spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up') + spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down') + + # geometric mean between up and downsampled + # be aware of the input dimensions + if spectra.ndim == 2: # noqa PLR2004 + spectra[i, :] = np.sqrt(spectrum_up * spectrum_dw) + if spectra.ndim == 3: # noqa PLR2004 + spectra[i, :, :] = np.sqrt(spectrum_up * spectrum_dw) + + aperiodic_spectrum = np.median(spectra, axis=0) + periodic_spectrum = orig_spectrum - aperiodic_spectrum + return orig_spectrum, aperiodic_spectrum, periodic_spectrum + + +def _crop_data( + band: list | tuple, + freqs: np.ndarray, + psd_aperiodic: np.ndarray, + psd_periodic: np.ndarray, + psd: np.ndarray, + axis: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Utility function to crop spectra to a defined frequency range""" mask_freqs = np.ma.masked_outside(freqs, *band).mask freqs = freqs[~mask_freqs] psd_aperiodic = np.compress(~mask_freqs, psd_aperiodic, axis=axis) psd_periodic = np.compress(~mask_freqs, psd_periodic, axis=axis) + psd = np.compress(~mask_freqs, psd, axis=axis) - return freqs, psd_aperiodic, psd_periodic + return freqs, psd_aperiodic, psd_periodic, psd def _gen_time_from_sft(SFT: type[dsp.ShortTimeFFT], sgramm: np.ndarray) -> np.ndarray: # noqa N803 diff --git a/pyrasa/utils/types.py b/pyrasa/utils/types.py index a4dd657..c8795a1 100644 --- a/pyrasa/utils/types.py +++ b/pyrasa/utils/types.py @@ -1,9 +1,25 @@ -from typing import Protocol +from typing import Protocol, TypedDict import numpy as np +import pandas as pd +from attrs import define class IrasaFun(Protocol): def __call__( self, data: np.ndarray, fs: int, h: int | None, up_down: str | None, time_orig: np.ndarray | None = None ) -> np.ndarray: ... + + +class IrasaSprintKwargsTyped(TypedDict): + mfft: int + hop: int + win_duration: float + dpss_settings: dict + win_kwargs: dict + + +@define +class SlopeFit: + aperiodic_params: pd.DataFrame + gof: pd.DataFrame diff --git a/tests/test_basic_irasa.py b/tests/test_basic_irasa.py index 7289f6d..5c24d6f 100644 --- a/tests/test_basic_irasa.py +++ b/tests/test_basic_irasa.py @@ -1,10 +1,7 @@ import numpy as np import pytest -import scipy.signal as dsp from pyrasa import irasa -from pyrasa.utils.aperiodic_utils import compute_slope -from pyrasa.utils.peak_utils import get_peak_params from .settings import EXPONENT, FS, MIN_CORR_PSD_CMB, OSC_FREQ, TOLERANCE @@ -18,20 +15,20 @@ @pytest.mark.parametrize('fs', FS, scope='session') def test_irasa(combined_signal, fs, osc_freq, exponent): f_range = [1, 100] - freqs_rasa, psd_ap, psd_pe = irasa(combined_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) + irasa_spectrum = irasa(combined_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) # test the shape of the output - assert freqs_rasa.shape[0] == psd_ap.shape[1] == psd_pe.shape[1] + assert irasa_spectrum.freqs.shape[0] == irasa_spectrum.aperiodic.shape[1] == irasa_spectrum.periodic.shape[1] # test the selected frequency range - assert bool(np.logical_and(freqs_rasa[0] == f_range[0], freqs_rasa[-1] == f_range[1])) + assert bool(np.logical_and(irasa_spectrum.freqs[0] == f_range[0], irasa_spectrum.freqs[-1] == f_range[1])) # test whether recombining periodic and aperiodic spectrum is equivalent to the original spectrum - freqs_psd, psd = dsp.welch(combined_signal, fs, nperseg=int(4 * fs)) - psd_cmb = psd_ap[0, :] + psd_pe[0, :] - freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1]) - r = np.corrcoef(psd[freq_logical], psd_cmb)[0, 1] + # freqs_psd, psd = dsp.welch(combined_signal, fs, nperseg=int(4 * fs)) + psd_cmb = irasa_spectrum.aperiodic[0, :] + irasa_spectrum.periodic[0, :] + # freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1]) + r = np.corrcoef(irasa_spectrum.raw_spectrum, psd_cmb)[0, 1] assert r > MIN_CORR_PSD_CMB # test whether we can reconstruct the exponent correctly - ap_params, _ = compute_slope(psd_ap, freqs_rasa, fit_func='fixed') - assert bool(np.isclose(ap_params['Exponent'][0], np.abs(exponent), atol=TOLERANCE)) + slope_fit = irasa_spectrum.get_slopes(fit_func='fixed') + assert bool(np.isclose(slope_fit.aperiodic_params['Exponent'][0], np.abs(exponent), atol=TOLERANCE)) # test whether we can reconstruct the peak frequency correctly - pe_params = get_peak_params(psd_pe, freqs_rasa) + pe_params = irasa_spectrum.get_peaks() assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq)) diff --git a/tests/test_compute_slope.py b/tests/test_compute_slope.py index 02c3f3f..c678f18 100644 --- a/tests/test_compute_slope.py +++ b/tests/test_compute_slope.py @@ -19,23 +19,23 @@ def test_slope_fitting_fixed(fixed_aperiodic_signal, fs, exponent): freqs, psd = freqs[freq_logical], psd[freq_logical] # test whether we can reconstruct the exponent correctly - ap_params_f, gof_f = compute_slope(psd, freqs, fit_func='fixed') - assert pytest.approx(ap_params_f['Exponent'][0], abs=TOLERANCE) == np.abs(exponent) + slope_fit_f = compute_slope(psd, freqs, fit_func='fixed') + assert pytest.approx(slope_fit_f.aperiodic_params['Exponent'][0], abs=TOLERANCE) == np.abs(exponent) # test goodness of fit should be close to r_squared == 1 for linear model - assert gof_f['r_squared'][0] > MIN_R2 + assert slope_fit_f.gof['r_squared'][0] > MIN_R2 # test if we can set fit bounds w/o error # _, _ = compute_slope(psd, freqs, fit_func='fixed', fit_bounds=[2, 50]) # bic and aic for fixed model should be better if linear - ap_params_k, gof_k = compute_slope(psd, freqs, fit_func='knee') + slope_fit_k = compute_slope(psd, freqs, fit_func='knee') # assert gof_k['AIC'][0] > gof['AIC'][0] - assert gof_k['BIC'][0] > gof_f['BIC'][0] + assert slope_fit_k.gof['BIC'][0] > slope_fit_f.gof['BIC'][0] # test the effect of scaling - ap_params_fs, gof_fs = compute_slope(psd, freqs, fit_func='fixed', scale=True) - assert np.isclose(ap_params_fs['Exponent'], ap_params_f['Exponent']) - assert np.isclose(gof_fs['r_squared'], gof_f['r_squared']) + slope_fit_fs = compute_slope(psd, freqs, fit_func='fixed', scale=True) + assert np.isclose(slope_fit_fs.aperiodic_params['Exponent'], slope_fit_f.aperiodic_params['Exponent']) + assert np.isclose(slope_fit_fs.gof['r_squared'], slope_fit_f.gof['r_squared']) @pytest.mark.parametrize('exponent, fs', [(-1, 500)], scope='session') @@ -64,25 +64,3 @@ def test_slope_fitting_settings( # test for warning with pytest.warns(UserWarning, match=match_txt): compute_slope(psd[freq_logical], freqs[freq_logical], fit_func='fixed') - - -# Takes too long need to pregenerate -# @pytest.mark.parametrize('exponent, fs, knee_freq', [(-1, 500, 15)], scope='session') -# def test_slope_fitting_knee(knee_aperiodic_signal, fs, exponent): -# f_range = [1, 200] -# # test whether recombining periodic and aperiodic spectrum is equivalent to the original spectrum -# freqs, psd = dsp.welch(knee_aperiodic_signal, fs, nperseg=int(4 * fs)) -# freq_logical = np.logical_and(freqs >= f_range[0], freqs <= f_range[1]) -# freqs, psd = freqs[freq_logical], psd[freq_logical] -# # test whether we can reconstruct the exponent correctly -# ap_params_k, gof_k = compute_slope(psd, freqs, fit_func='knee') -# ap_params_f, gof_f = compute_slope(psd, freqs, fit_func='fixed') -# # assert pytest.approx(ap_params_k['Exponent_1'][0], abs=TOLERANCE) == 0 -# assert bool(np.isclose(ap_params_k['Exponent_2'][0], np.abs(exponent), atol=TOLERANCE)) -# assert bool(np.isclose(ap_params_k['Knee Frequency (Hz)'][0], KNEE_FREQ, atol=KNEE_TOLERANCE)) -# # test goodness of fit -# assert gof_k['r_squared'][0] > MIN_R2 -# assert gof_k['r_squared'][0] > gof_f['r_squared'][0] # r2 for knee model should be higher than knee if knee -# # bic and aic for knee model should be better if knee -# assert gof_k['AIC'][0] < gof_f['AIC'][0] -# assert gof_k['BIC'][0] < gof_f['BIC'][0] diff --git a/tests/test_irasa_knee.py b/tests/test_irasa_knee.py index 90c24ea..c858c87 100644 --- a/tests/test_irasa_knee.py +++ b/tests/test_irasa_knee.py @@ -3,8 +3,6 @@ import scipy.signal as dsp from pyrasa import irasa -from pyrasa.utils.aperiodic_utils import compute_slope -from pyrasa.utils.peak_utils import get_peak_params from .settings import EXP_KNEE_COMBO, FS, KNEE_TOLERANCE, MIN_CORR_PSD_CMB, OSC_FREQ, TOLERANCE @@ -17,27 +15,29 @@ @pytest.mark.parametrize('fs', FS, scope='session') def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee): f_range = [0.1, 100] - freqs_rasa, psd_ap, psd_pe = irasa(load_knee_aperiodic_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) + irasa_out = irasa(load_knee_aperiodic_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) # test the shape of the output - assert freqs_rasa.shape[0] == psd_ap.shape[1] == psd_pe.shape[1] + assert irasa_out.freqs.shape[0] == irasa_out.aperiodic.shape[1] == irasa_out.periodic.shape[1] freqs_psd, psd = dsp.welch(load_knee_aperiodic_signal, fs, nperseg=int(4 * fs)) - psd_cmb = psd_ap[0, :] + psd_pe[0, :] + psd_cmb = irasa_out.aperiodic[0, :] + irasa_out.periodic[0, :] freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1]) r = np.corrcoef(psd[freq_logical], psd_cmb)[0, 1] assert r > MIN_CORR_PSD_CMB - ap_params_k, gof_k = compute_slope(psd_ap, freqs_rasa, fit_func='knee') - ap_params_f, gof_f = compute_slope(psd_ap, freqs_rasa, fit_func='fixed') + slope_fit_k = irasa_out.get_slopes(fit_func='knee') + slope_fit_f = irasa_out.get_slopes(fit_func='fixed') # test whether we can get the first exponent correctly - assert bool(np.isclose(ap_params_k['Exponent_1'][0], 0, atol=TOLERANCE)) + assert bool(np.isclose(slope_fit_k.aperiodic_params['Exponent_1'][0], 0, atol=TOLERANCE)) # test whether we can get the second exponent correctly - assert bool(np.isclose(ap_params_k['Exponent_2'][0], np.abs(exponent), atol=TOLERANCE)) + assert bool(np.isclose(slope_fit_k.aperiodic_params['Exponent_2'][0], np.abs(exponent), atol=TOLERANCE)) # test whether we can get the knee correctly - knee_hat = ap_params_k['Knee'][0] ** (1 / (2 * ap_params_k['Exponent_1'][0] + ap_params_k['Exponent_2'][0])) + knee_hat = slope_fit_k.aperiodic_params['Knee'][0] ** ( + 1 / (2 * slope_fit_k.aperiodic_params['Exponent_1'][0] + slope_fit_k.aperiodic_params['Exponent_2'][0]) + ) knee_real = knee ** (1 / np.abs(exponent)) assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE)) # test bic/aic -> should be better for knee - assert gof_k['AIC'][0] < gof_f['AIC'][0] - assert gof_k['BIC'][0] < gof_f['BIC'][0] + assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0] + assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0] # knee model @@ -46,27 +46,29 @@ def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee): @pytest.mark.parametrize('osc_freq', OSC_FREQ, scope='session') def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee, osc_freq): f_range = [0.1, 100] - freqs_rasa, psd_ap, psd_pe = irasa(load_knee_cmb_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) + irasa_out = irasa(load_knee_cmb_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) # test the shape of the output - assert freqs_rasa.shape[0] == psd_ap.shape[1] == psd_pe.shape[1] + assert irasa_out.freqs.shape[0] == irasa_out.aperiodic.shape[1] == irasa_out.periodic.shape[1] freqs_psd, psd = dsp.welch(load_knee_cmb_signal, fs, nperseg=int(4 * fs)) - psd_cmb = psd_ap[0, :] + psd_pe[0, :] + psd_cmb = irasa_out.aperiodic[0, :] + irasa_out.periodic[0, :] freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1]) r = np.corrcoef(psd[freq_logical], psd_cmb)[0, 1] assert r > MIN_CORR_PSD_CMB - ap_params_k, gof_k = compute_slope(psd_ap, freqs_rasa, fit_func='knee') - ap_params_f, gof_f = compute_slope(psd_ap, freqs_rasa, fit_func='fixed') + slope_fit_k = irasa_out.get_slopes(fit_func='knee') + slope_fit_f = irasa_out.get_slopes(fit_func='fixed') # test whether we can get the first exponent correctly - assert bool(np.isclose(ap_params_k['Exponent_1'][0], 0, atol=TOLERANCE)) + assert bool(np.isclose(slope_fit_k.aperiodic_params['Exponent_1'][0], 0, atol=TOLERANCE)) # test whether we can get the second exponent correctly - assert bool(np.isclose(ap_params_k['Exponent_2'][0], np.abs(exponent), atol=TOLERANCE)) + assert bool(np.isclose(slope_fit_k.aperiodic_params['Exponent_2'][0], np.abs(exponent), atol=TOLERANCE)) # test whether we can get the knee correctly - knee_hat = ap_params_k['Knee'][0] ** (1 / (2 * ap_params_k['Exponent_1'][0] + ap_params_k['Exponent_2'][0])) + knee_hat = slope_fit_k.aperiodic_params['Knee'][0] ** ( + 1 / (2 * slope_fit_k.aperiodic_params['Exponent_1'][0] + slope_fit_k.aperiodic_params['Exponent_2'][0]) + ) knee_real = knee ** (1 / np.abs(exponent)) assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE)) # test bic/aic -> should be better for knee - assert gof_k['AIC'][0] < gof_f['AIC'][0] - assert gof_k['BIC'][0] < gof_f['BIC'][0] + assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0] + assert slope_fit_k.gof['BIC'][0] < slope_fit_f.gof['BIC'][0] # test whether we can reconstruct the peak frequency correctly - pe_params = get_peak_params(psd_pe, freqs_rasa) + pe_params = irasa_out.get_peaks() assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq)) diff --git a/tests/test_irasa_sprint.py b/tests/test_irasa_sprint.py index 8d1a954..1cc0e86 100644 --- a/tests/test_irasa_sprint.py +++ b/tests/test_irasa_sprint.py @@ -3,8 +3,7 @@ from neurodsp.utils.sim import set_random_seed from pyrasa.irasa import irasa_sprint -from pyrasa.utils.aperiodic_utils import compute_slope_sprint -from pyrasa.utils.peak_utils import get_band_info, get_peak_params_sprint +from pyrasa.utils.peak_utils import get_band_info from .settings import MIN_R2_SPRINT, TOLERANCE @@ -12,25 +11,24 @@ def test_irasa_sprint(ts4sprint): - sgramm_ap, sgramm_p, freqs_ir, times_ir = irasa_sprint( + irasa_tf = irasa_sprint( ts4sprint[np.newaxis, :], fs=500, band=(1, 100), - freq_res=0.5, # smooth=False, n_avgs=[3, 7, 11] + freq_res=0.5, ) # check basic aperiodic detection - df_aps, df_gof = compute_slope_sprint(sgramm_ap[np.newaxis, :, :], freqs=freqs_ir, times=times_ir, fit_func='fixed') + slope_fit = irasa_tf.get_slopes(fit_func='fixed') + # irasa_tf.aperiodic[np.newaxis, :, :], freqs=irasa_tf.freqs, times=irasa_tf.time, + # ) - assert df_gof['r_squared'].mean() > MIN_R2_SPRINT - assert np.isclose(df_aps.query('time < 7')['Exponent'].mean(), 1, atol=TOLERANCE) - assert np.isclose(df_aps.query('time > 7')['Exponent'].mean(), 2, atol=TOLERANCE) + assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT + assert np.isclose(slope_fit.aperiodic_params.query('time < 7')['Exponent'].mean(), 1, atol=TOLERANCE) + assert np.isclose(slope_fit.aperiodic_params.query('time > 7')['Exponent'].mean(), 2, atol=TOLERANCE) # check basic peak detection - df_peaks = get_peak_params_sprint( - sgramm_p[np.newaxis, :, :], - freqs=freqs_ir, - times=times_ir, + df_peaks = irasa_tf.get_peaks( smooth=True, smoothing_window=1, min_peak_height=0.01, @@ -73,11 +71,6 @@ def test_irasa_sprint(ts4sprint): # test settings def test_irasa_sprint_settings(ts4sprint): - # test smoothing - # sgramm_ap, sgramm_p, freqs_ir, times_ir = irasa_sprint( - # ts4sprint[np.newaxis, :], fs=500, band=(1, 100), freq_res=0.5, smooth=True, n_avgs=[3] - # ) - # test dpss import scipy.signal as dsp @@ -87,8 +80,6 @@ def test_irasa_sprint_settings(ts4sprint): band=(1, 100), win_func=dsp.windows.dpss, freq_res=0.5, - # smooth=False, - # n_avgs=[3, 7, 11], ) # test too much bandwidht @@ -100,6 +91,4 @@ def test_irasa_sprint_settings(ts4sprint): win_func=dsp.windows.dpss, dpss_settings_time_bandwidth=1, freq_res=0.5, - # smooth=False, - # n_avgs=[3, 7, 11], ) diff --git a/tests/test_mne.py b/tests/test_mne.py index aad426c..46bd47c 100644 --- a/tests/test_mne.py +++ b/tests/test_mne.py @@ -12,13 +12,11 @@ def test_mne(gen_mne_data_raw): mne_data, epochs = gen_mne_data_raw # test raw - aperiodic_mne, periodic_mne = irasa_raw( - mne_data, band=(0.25, 50), duration=2, hset_info=(1.0, 2.0, 0.05), as_array=False - ) - aperiodic_mne.get_slopes(fit_func='fixed') - periodic_mne.get_peaks(smoothing_window=2) + irasa_raw_result = irasa_raw(mne_data, band=(0.25, 50), duration=2, hset_info=(1.0, 2.0, 0.05)) + irasa_raw_result.aperiodic.get_slopes(fit_func='fixed') + irasa_raw_result.periodic.get_peaks(smoothing_window=2) # test epochs - aperiodic, periodic = irasa_epochs(epochs, band=(0.5, 50), hset_info=(1.0, 2.0, 0.05), as_array=False) - aperiodic.get_slopes(fit_func='fixed', scale=True) - periodic.get_peaks(smoothing_window=2) + irasa_epoched_result = irasa_epochs(epochs, band=(0.5, 50), hset_info=(1.0, 2.0, 0.05)) + irasa_epoched_result.aperiodic.get_slopes(fit_func='fixed', scale=True) + irasa_epoched_result.periodic.get_peaks(smoothing_window=2)