Skip to content

Commit

Permalink
gave irasa_spectrum methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Schmidt Fabian committed Aug 1, 2024
1 parent f12e5da commit d19b4cb
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 97 deletions.
14 changes: 10 additions & 4 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ def _local_irasa_fun(
data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset
)

freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1)
freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1)

return IrasaSpectrum(freqs=freq, aperiodic=psd_aperiodic, periodic=psd_periodic)
return IrasaSpectrum(freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic)


# irasa sprint
Expand Down Expand Up @@ -286,12 +286,18 @@ def _local_irasa_fun(
)

# NOTE: we need to transpose the data as crop_data extracts stuff from the last axis
freq, sgramm_aperiodic, sgramm_periodic = _crop_data(band, freq, sgramm_aperiodic, sgramm_periodic, axis=0)
freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data(
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0
)

# adjust time info (i.e. cut the padded stuff)
tmax = data.shape[1] / fs
t_mask = np.logical_and(time >= 0, time < tmax)

return IrasaTfSpectrum(
freqs=freq, time=time[t_mask], periodic=sgramm_periodic[:, t_mask], aperiodic=sgramm_aperiodic[:, t_mask]
freqs=freq,
time=time[t_mask],
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask],
aperiodic=sgramm_aperiodic[:, t_mask],
)
21 changes: 10 additions & 11 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pyrasa.utils.aperiodic_utils import compute_slope
from pyrasa.utils.peak_utils import get_peak_params
from pyrasa.utils.types import SlopeFit


class PeriodicSpectrumArray(SpectrumArray):
Expand Down Expand Up @@ -170,7 +171,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -191,7 +192,7 @@ def get_slopes(
"""

df_aps, df_gof = compute_slope(
return compute_slope(
self.get_data(),
self.freqs,
ch_names=self.ch_names,
Expand All @@ -200,8 +201,6 @@ def get_slopes(
fit_bounds=fit_bounds,
)

return df_aps, df_gof


# %%
class PeriodicEpochsSpectrum(EpochsSpectrumArray):
Expand Down Expand Up @@ -395,7 +394,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -421,7 +420,7 @@ def get_slopes(

aps_list, gof_list = [], []
for ix, cur_epoch in enumerate(self.get_data()):
df_aps, df_gof = compute_slope(
slope_fit = compute_slope(
cur_epoch,
self.freqs,
ch_names=self.ch_names,
Expand All @@ -430,12 +429,12 @@ def get_slopes(
fit_bounds=fit_bounds,
)

df_aps['event_id'] = event_dict[events[ix]]
df_gof['event_id'] = event_dict[events[ix]]
aps_list.append(df_aps)
gof_list.append(df_gof)
slope_fit.aperiodic_params['event_id'] = event_dict[events[ix]]
slope_fit.gof['event_id'] = event_dict[events[ix]]
aps_list.append(slope_fit.aperiodic_params.copy())
gof_list.append(slope_fit.gof.copy())

return pd.concat(aps_list), pd.concat(gof_list)
return SlopeFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list))


@define
Expand Down
26 changes: 11 additions & 15 deletions pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd
from scipy.optimize import curve_fit

from pyrasa.utils.types import SlopeFit


def fixed_model(x: np.ndarray, b0: float, b: float) -> np.ndarray:
"""
Expand Down Expand Up @@ -137,7 +139,7 @@ def compute_slope(
ch_names: Iterable = (),
scale: bool = False,
fit_bounds: tuple[float, float] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This function can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand Down Expand Up @@ -224,10 +226,7 @@ def num_zeros(decimal: int) -> float:
gof_list.append(gof)

# combine & return
df_aps = pd.concat(ap_list)
df_gof = pd.concat(gof_list)

return df_aps, df_gof
return SlopeFit(aperiodic_params=pd.concat(ap_list), gof=pd.concat(gof_list))


def compute_slope_sprint(
Expand All @@ -237,7 +236,7 @@ def compute_slope_sprint(
fit_func: str,
ch_names: Iterable = (),
fit_bounds: tuple[float, float] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This function can be used to extract aperiodic parameters from the aperiodic spectrogram extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand Down Expand Up @@ -268,16 +267,13 @@ def compute_slope_sprint(
ap_t_list, gof_t_list = [], []

for ix, t in enumerate(times):
cur_aps, cur_gof = compute_slope(
slope_fit = compute_slope(
aperiodic_spectrum[:, :, ix], freqs=freqs, fit_func=fit_func, ch_names=ch_names, fit_bounds=fit_bounds
)
cur_aps['time'] = t
cur_gof['time'] = t

ap_t_list.append(cur_aps)
gof_t_list.append(cur_gof)
slope_fit.aperiodic_params['time'] = t
slope_fit.gof['time'] = t

df_ap_time = pd.concat(ap_t_list)
df_gof_time = pd.concat(gof_t_list)
ap_t_list.append(slope_fit.aperiodic_params)
gof_t_list.append(slope_fit.gof)

return df_ap_time, df_gof_time
return SlopeFit(aperiodic_params=pd.concat(ap_t_list), gof=pd.concat(gof_t_list))
84 changes: 84 additions & 0 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,93 @@
import numpy as np
import pandas as pd
from attrs import define

from pyrasa.utils.aperiodic_utils import compute_slope
from pyrasa.utils.peak_utils import get_peak_params
from pyrasa.utils.types import SlopeFit


@define
class IrasaSpectrum:
freqs: np.ndarray
raw_spectrum: np.ndarray
aperiodic: np.ndarray
periodic: np.ndarray
# ch_names: np.ndarray | None

def get_slopes(
self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
as well as the respective goodness of fit.
Parameters:
fit_func : string
Can be either "fixed" or "knee".
fit_bounds : None, tuple
Lower and upper bound for the fit function,
should be None if the whole frequency range is desired.
Otherwise a tuple of (lower, upper)
Returns: SlopeFit
df_aps: DataFrame
DataFrame containing the center frequency, bandwidth and peak height for each channel
df_gof: DataFrame
DataFrame containing the goodness of fit of the specific fit function for each channel.
"""
return compute_slope(
self.aperiodic,
self.freqs,
# ch_names=self.ch_names,
scale=scale,
fit_func=fit_func,
fit_bounds=fit_bounds,
)

def get_peaks(
self,
smoothing_window: float | int = 1,
cut_spectrum: tuple[float, float] | None = None,
peak_threshold: float = 2.5,
min_peak_height: float = 0.0,
polyorder: int = 1,
peak_width_limits: tuple[float, float] = (0.5, 12),
) -> pd.DataFrame:
"""
This method can be used to extract peak parameters from the periodic spectrum extracted from IRASA.
The algorithm works by smoothing the spectrum, zeroing out negative values and
extracting peaks based on user specified parameters.
Parameters: smoothing window : int, optional, default: 2
Smoothing window in Hz handed over to the savitzky-golay filter.
cut_spectrum : tuple of (float, float), optional, default (1, 40)
Cut the periodic spectrum to limit peak finding to a sensible range
peak_threshold : float, optional, default: 1
Relative threshold for detecting peaks. This threshold is defined in
relative units of the periodic spectrum
min_peak_height : float, optional, default: 0.01
Absolute threshold for identifying peaks. The threhsold is defined in relative
units of the power spectrum. Setting this is somewhat necessary when a
"knee" is present in the data as it will carry over to the periodic spctrum in irasa.
peak_width_limits : tuple of (float, float), optional, default (.5, 12)
Limits on possible peak width, in Hz, as (lower_bound, upper_bound)
Returns: df_peaks: DataFrame
DataFrame containing the center frequency, bandwidth and peak height for each channel
"""

return get_peak_params(
self.periodic,
self.freqs,
# self.ch_names,
smoothing_window=smoothing_window,
cut_spectrum=cut_spectrum,
peak_threshold=peak_threshold,
min_peak_height=min_peak_height,
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)
1 change: 1 addition & 0 deletions pyrasa/utils/irasa_tf_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
class IrasaTfSpectrum:
freqs: np.ndarray
time: np.ndarray
raw_spectrum: np.ndarray
aperiodic: np.ndarray
periodic: np.ndarray
10 changes: 8 additions & 2 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,22 @@ def _gen_irasa(


def _crop_data(
band: list | tuple, freqs: np.ndarray, psd_aperiodic: np.ndarray, psd_periodic: np.ndarray, axis: int
band: list | tuple,
freqs: np.ndarray,
psd_aperiodic: np.ndarray,
psd_periodic: np.ndarray,
psd: np.ndarray,
axis: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Utility function to crop spectra to a defined frequency range"""

mask_freqs = np.ma.masked_outside(freqs, *band).mask
freqs = freqs[~mask_freqs]
psd_aperiodic = np.compress(~mask_freqs, psd_aperiodic, axis=axis)
psd_periodic = np.compress(~mask_freqs, psd_periodic, axis=axis)
psd = np.compress(~mask_freqs, psd, axis=axis)

return freqs, psd_aperiodic, psd_periodic
return freqs, psd_aperiodic, psd_periodic, psd


def _gen_time_from_sft(SFT: type[dsp.ShortTimeFFT], sgramm: np.ndarray) -> np.ndarray: # noqa N803
Expand Down
8 changes: 8 additions & 0 deletions pyrasa/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Protocol, TypedDict

import numpy as np
import pandas as pd
from attrs import define


class IrasaFun(Protocol):
Expand All @@ -15,3 +17,9 @@ class IrasaSprintKwargsTyped(TypedDict):
win_duration: float
dpss_settings: dict
win_kwargs: dict


@define
class SlopeFit:
aperiodic_params: pd.DataFrame
gof: pd.DataFrame
23 changes: 10 additions & 13 deletions tests/test_basic_irasa.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import numpy as np
import pytest
import scipy.signal as dsp

from pyrasa import irasa
from pyrasa.utils.aperiodic_utils import compute_slope
from pyrasa.utils.peak_utils import get_peak_params

from .settings import EXPONENT, FS, MIN_CORR_PSD_CMB, OSC_FREQ, TOLERANCE

Expand All @@ -18,20 +15,20 @@
@pytest.mark.parametrize('fs', FS, scope='session')
def test_irasa(combined_signal, fs, osc_freq, exponent):
f_range = [1, 100]
irasa_out = irasa(combined_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs})
irasa_spectrum = irasa(combined_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs})
# test the shape of the output
assert irasa_out.freqs.shape[0] == irasa_out.aperiodic.shape[1] == irasa_out.periodic.shape[1]
assert irasa_spectrum.freqs.shape[0] == irasa_spectrum.aperiodic.shape[1] == irasa_spectrum.periodic.shape[1]
# test the selected frequency range
assert bool(np.logical_and(irasa_out.freqs[0] == f_range[0], irasa_out.freqs[-1] == f_range[1]))
assert bool(np.logical_and(irasa_spectrum.freqs[0] == f_range[0], irasa_spectrum.freqs[-1] == f_range[1]))
# test whether recombining periodic and aperiodic spectrum is equivalent to the original spectrum
freqs_psd, psd = dsp.welch(combined_signal, fs, nperseg=int(4 * fs))
psd_cmb = irasa_out.aperiodic[0, :] + irasa_out.periodic[0, :]
freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1])
r = np.corrcoef(psd[freq_logical], psd_cmb)[0, 1]
# freqs_psd, psd = dsp.welch(combined_signal, fs, nperseg=int(4 * fs))
psd_cmb = irasa_spectrum.aperiodic[0, :] + irasa_spectrum.periodic[0, :]
# freq_logical = np.logical_and(freqs_psd >= f_range[0], freqs_psd <= f_range[1])
r = np.corrcoef(irasa_spectrum.raw_spectrum, psd_cmb)[0, 1]
assert r > MIN_CORR_PSD_CMB
# test whether we can reconstruct the exponent correctly
ap_params, _ = compute_slope(irasa_out.aperiodic, irasa_out.freqs, fit_func='fixed')
assert bool(np.isclose(ap_params['Exponent'][0], np.abs(exponent), atol=TOLERANCE))
slope_fit = irasa_spectrum.get_slopes(fit_func='fixed')
assert bool(np.isclose(slope_fit.aperiodic_params['Exponent'][0], np.abs(exponent), atol=TOLERANCE))
# test whether we can reconstruct the peak frequency correctly
pe_params = get_peak_params(irasa_out.periodic, irasa_out.freqs)
pe_params = irasa_spectrum.get_peaks()
assert bool(np.isclose(np.round(pe_params['cf'], 0), osc_freq))
Loading

0 comments on commit d19b4cb

Please sign in to comment.