Skip to content

Commit

Permalink
added methods to irasa_tf_spectrum
Browse files Browse the repository at this point in the history
  • Loading branch information
Schmidt Fabian committed Aug 1, 2024
1 parent fd2edd1 commit 3a44d8f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 19 deletions.
8 changes: 7 additions & 1 deletion pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def compute_slope_sprint(
freqs: np.ndarray,
times: np.ndarray,
fit_func: str,
scale: bool = False,
ch_names: Iterable = (),
fit_bounds: tuple[float, float] | None = None,
) -> SlopeFit:
Expand Down Expand Up @@ -268,7 +269,12 @@ def compute_slope_sprint(

for ix, t in enumerate(times):
slope_fit = compute_slope(
aperiodic_spectrum[:, :, ix], freqs=freqs, fit_func=fit_func, ch_names=ch_names, fit_bounds=fit_bounds
aperiodic_spectrum[:, :, ix],
freqs=freqs,
fit_func=fit_func,
ch_names=ch_names,
fit_bounds=fit_bounds,
scale=scale,
)
slope_fit.aperiodic_params['time'] = t
slope_fit.gof['time'] = t
Expand Down
88 changes: 88 additions & 0 deletions pyrasa/utils/irasa_tf_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import numpy as np
import pandas as pd
from attrs import define

from pyrasa.utils.aperiodic_utils import compute_slope_sprint
from pyrasa.utils.peak_utils import get_peak_params_sprint
from pyrasa.utils.types import SlopeFit

min_ndim = 2


@define
class IrasaTfSpectrum:
Expand All @@ -9,3 +16,84 @@ class IrasaTfSpectrum:
raw_spectrum: np.ndarray
aperiodic: np.ndarray
periodic: np.ndarray

def get_slopes(
self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
as well as the respective goodness of fit.
Parameters:
fit_func : string
Can be either "fixed" or "knee".
fit_bounds : None, tuple
Lower and upper bound for the fit function,
should be None if the whole frequency range is desired.
Otherwise a tuple of (lower, upper)
Returns: SlopeFit
df_aps: DataFrame
DataFrame containing the center frequency, bandwidth and peak height for each channel
df_gof: DataFrame
DataFrame containing the goodness of fit of the specific fit function for each channel.
"""
return compute_slope_sprint(
aperiodic_spectrum=self.aperiodic[np.newaxis, :, :] if self.aperiodic.ndim == min_ndim else self.aperiodic,
freqs=self.freqs,
times=self.time,
# ch_names=self.ch_names,
scale=scale,
fit_func=fit_func,
fit_bounds=fit_bounds,
)

def get_peaks(
self,
smooth: bool = True,
smoothing_window: float | int = 1,
cut_spectrum: tuple[float, float] | None = None,
peak_threshold: float = 2.5,
min_peak_height: float = 0.0,
polyorder: int = 1,
peak_width_limits: tuple[float, float] = (0.5, 12),
) -> pd.DataFrame:
"""
This method can be used to extract peak parameters from the periodic spectrum extracted from IRASA.
The algorithm works by smoothing the spectrum, zeroing out negative values and
extracting peaks based on user specified parameters.
Parameters: smoothing window : int, optional, default: 2
Smoothing window in Hz handed over to the savitzky-golay filter.
cut_spectrum : tuple of (float, float), optional, default (1, 40)
Cut the periodic spectrum to limit peak finding to a sensible range
peak_threshold : float, optional, default: 1
Relative threshold for detecting peaks. This threshold is defined in
relative units of the periodic spectrum
min_peak_height : float, optional, default: 0.01
Absolute threshold for identifying peaks. The threhsold is defined in relative
units of the power spectrum. Setting this is somewhat necessary when a
"knee" is present in the data as it will carry over to the periodic spctrum in irasa.
peak_width_limits : tuple of (float, float), optional, default (.5, 12)
Limits on possible peak width, in Hz, as (lower_bound, upper_bound)
Returns: df_peaks: DataFrame
DataFrame containing the center frequency, bandwidth and peak height for each channel
"""

return get_peak_params_sprint(
periodic_spectrum=self.periodic[np.newaxis, :, :] if self.periodic.ndim == min_ndim else self.periodic,
freqs=self.freqs,
times=self.time,
# self.ch_names,
smooth=smooth,
smoothing_window=smoothing_window,
cut_spectrum=cut_spectrum,
peak_threshold=peak_threshold,
min_peak_height=min_peak_height,
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)
23 changes: 5 additions & 18 deletions tests/test_irasa_sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from neurodsp.utils.sim import set_random_seed

from pyrasa.irasa import irasa_sprint
from pyrasa.utils.aperiodic_utils import compute_slope_sprint
from pyrasa.utils.peak_utils import get_band_info, get_peak_params_sprint
from pyrasa.utils.peak_utils import get_band_info

from .settings import MIN_R2_SPRINT, TOLERANCE

Expand All @@ -20,19 +19,16 @@ def test_irasa_sprint(ts4sprint):
)

# check basic aperiodic detection
slope_fit = compute_slope_sprint(
irasa_tf.aperiodic[np.newaxis, :, :], freqs=irasa_tf.freqs, times=irasa_tf.time, fit_func='fixed'
)
slope_fit = irasa_tf.get_slopes(fit_func='fixed')
# irasa_tf.aperiodic[np.newaxis, :, :], freqs=irasa_tf.freqs, times=irasa_tf.time,
# )

assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT
assert np.isclose(slope_fit.aperiodic_params.query('time < 7')['Exponent'].mean(), 1, atol=TOLERANCE)
assert np.isclose(slope_fit.aperiodic_params.query('time > 7')['Exponent'].mean(), 2, atol=TOLERANCE)

# check basic peak detection
df_peaks = get_peak_params_sprint(
irasa_tf.periodic[np.newaxis, :, :],
freqs=irasa_tf.freqs,
times=irasa_tf.time,
df_peaks = irasa_tf.get_peaks(
smooth=True,
smoothing_window=1,
min_peak_height=0.01,
Expand Down Expand Up @@ -75,11 +71,6 @@ def test_irasa_sprint(ts4sprint):

# test settings
def test_irasa_sprint_settings(ts4sprint):
# test smoothing
# sgramm_ap, sgramm_p, freqs_ir, times_ir = irasa_sprint(
# ts4sprint[np.newaxis, :], fs=500, band=(1, 100), freq_res=0.5, smooth=True, n_avgs=[3]
# )

# test dpss
import scipy.signal as dsp

Expand All @@ -89,8 +80,6 @@ def test_irasa_sprint_settings(ts4sprint):
band=(1, 100),
win_func=dsp.windows.dpss,
freq_res=0.5,
# smooth=False,
# n_avgs=[3, 7, 11],
)

# test too much bandwidht
Expand All @@ -102,6 +91,4 @@ def test_irasa_sprint_settings(ts4sprint):
win_func=dsp.windows.dpss,
dpss_settings_time_bandwidth=1,
freq_res=0.5,
# smooth=False,
# n_avgs=[3, 7, 11],
)

0 comments on commit 3a44d8f

Please sign in to comment.