Skip to content

Commit

Permalink
Merge pull request #48 from schmidtfa/add_attrs
Browse files Browse the repository at this point in the history
starting attrs integration
  • Loading branch information
schmidtfa authored Aug 1, 2024
2 parents 87a384d + bc0c916 commit 6d8fa14
Show file tree
Hide file tree
Showing 16 changed files with 423 additions and 225 deletions.
3 changes: 2 additions & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
]
keywords=['spectral parametrization', 'oscillations', 'power spectra', '1/f']
requires-python = ">= 3.11"
dependencies = ["numpy", "pandas", "scipy>=1.12"]
dependencies = ["numpy", "pandas", "scipy>=1.12", "attrs"]

[project.optional-dependencies]
mne = ['mne']
Expand Down
79 changes: 25 additions & 54 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,23 @@
import fractions
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np
import scipy.signal as dsp

from pyrasa.utils.irasa_spectrum import IrasaSpectrum
from pyrasa.utils.irasa_tf_spectrum import IrasaTfSpectrum

# from scipy.stats.mstats import gmean
from pyrasa.utils.irasa_utils import (
_check_irasa_settings,
_compute_psd_welch,
_compute_sgramm,
_crop_data, # _find_nearest, _gen_time_from_sft, _get_windows,
_crop_data,
_gen_irasa,
)
from pyrasa.utils.types import IrasaFun

if TYPE_CHECKING:
from pyrasa.utils.input_classes import IrasaSprintKwargsTyped


# TODO: Port to Cython
def _gen_irasa(
data: np.ndarray,
orig_spectrum: np.ndarray,
fs: int,
irasa_fun: IrasaFun,
hset: np.ndarray,
time: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
This function is implementing the IRASA algorithm using a custom function to
compute a power/cross-spectral density and returns an "original", "periodic" and "aperiodic spectrum".
This implementation of the IRASA algorithm is based on the yasa.irasa function in (Vallat & Walker, 2021).
[1] Vallat, Raphael, and Matthew P. Walker. “An open-source,
high-performance tool for automated sleep staging.”
Elife 10 (2021). doi: https://doi.org/10.7554/eLife.70092
"""

spectra = np.zeros((len(hset), *orig_spectrum.shape))
for i, h in enumerate(hset):
rat = fractions.Fraction(str(h))
up, down = rat.numerator, rat.denominator

# Much faster than FFT-based resampling
data_up = dsp.resample_poly(data, up, down, axis=-1)
data_down = dsp.resample_poly(data, down, up, axis=-1)

# Calculate an up/downsampled version of the PSD using same params as original
spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up')
spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down')

# geometric mean between up and downsampled
# be aware of the input dimensions
if spectra.ndim == 2: # noqa PLR2004
spectra[i, :] = np.sqrt(spectrum_up * spectrum_dw)
if spectra.ndim == 3: # noqa PLR2004
spectra[i, :, :] = np.sqrt(spectrum_up * spectrum_dw)

aperiodic_spectrum = np.median(spectra, axis=0)
periodic_spectrum = orig_spectrum - aperiodic_spectrum
return orig_spectrum, aperiodic_spectrum, periodic_spectrum
from pyrasa.utils.types import IrasaSprintKwargsTyped


# %% irasa
Expand All @@ -68,6 +26,7 @@ def irasa(
fs: int,
band: tuple[float, float],
psd_kwargs: dict,
ch_names: np.ndarray | None = None,
win_func: Callable = dsp.windows.hann,
win_func_kwargs: dict | None = None,
dpss_settings_time_bandwidth: float = 2.0,
Expand All @@ -76,7 +35,7 @@ def irasa(
filter_settings: tuple[float | None, float | None] = (None, None),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
hset_accuracy: int = 4,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaSpectrum:
"""
This function can be used to generate aperiodic and periodic power spectra from a time series
using the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -181,15 +140,18 @@ def _local_irasa_fun(
data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset
)

freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1)
freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1)

return freq, psd_aperiodic, psd_periodic
return IrasaSpectrum(
freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic, ch_names=ch_names
)


# irasa sprint
def irasa_sprint( # noqa PLR0915 C901
data: np.ndarray,
fs: int,
ch_names: np.ndarray | None = None,
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
win_duration: float = 0.4,
Expand All @@ -202,7 +164,7 @@ def irasa_sprint( # noqa PLR0915 C901
filter_settings: tuple[float | None, float | None] = (None, None),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
hset_accuracy: int = 4,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaTfSpectrum:
"""
This function can be used to seperate aperiodic from periodic power spectra
Expand Down Expand Up @@ -328,10 +290,19 @@ def _local_irasa_fun(
)

# NOTE: we need to transpose the data as crop_data extracts stuff from the last axis
freq, sgramm_aperiodic, sgramm_periodic = _crop_data(band, freq, sgramm_aperiodic, sgramm_periodic, axis=0)
freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data(
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0
)

# adjust time info (i.e. cut the padded stuff)
tmax = data.shape[1] / fs
t_mask = np.logical_and(time >= 0, time < tmax)

return sgramm_aperiodic[:, t_mask], sgramm_periodic[:, t_mask], freq, time[t_mask]
return IrasaTfSpectrum(
freqs=freq,
time=time[t_mask],
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask],
aperiodic=sgramm_aperiodic[:, t_mask],
ch_names=ch_names,
)
51 changes: 24 additions & 27 deletions pyrasa/irasa_mne/irasa_mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pyrasa.irasa_mne.mne_objs import (
AperiodicEpochsSpectrum,
AperiodicSpectrumArray,
IrasaEpoched,
IrasaRaw,
PeriodicEpochsSpectrum,
PeriodicSpectrumArray,
)
Expand All @@ -16,8 +18,7 @@ def irasa_raw(
duration: float | None = None,
overlap: float | int = 50,
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
as_array: bool = False,
) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaRaw:
"""
This function can be used to seperate aperiodic from periodic power spectra using
the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -87,7 +88,7 @@ def irasa_raw(
'noverlap': int(fs * duration * overlap),
}

freq, psd_aperiodic, psd_periodic = irasa(
irasa_spectrum = irasa(
data_array,
fs=fs,
band=band,
Expand All @@ -96,22 +97,17 @@ def irasa_raw(
psd_kwargs=kwargs_psd,
)

if as_array is True:
return psd_aperiodic, psd_periodic, freq

else:
aperiodic = AperiodicSpectrumArray(psd_aperiodic, info, freqs=freq)
periodic = PeriodicSpectrumArray(psd_periodic, info, freqs=freq)

return aperiodic, periodic
return IrasaRaw(
periodic=PeriodicSpectrumArray(irasa_spectrum.periodic, info, freqs=irasa_spectrum.freqs),
aperiodic=AperiodicSpectrumArray(irasa_spectrum.aperiodic, info, freqs=irasa_spectrum.freqs),
)


def irasa_epochs(
data: mne.Epochs,
band: tuple[float, float] = (1.0, 100.0),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
as_array: bool = False,
) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaEpoched:
"""
This function can be used to seperate aperiodic from periodic power spectra
using the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -177,24 +173,25 @@ def irasa_epochs(
# Do the actual IRASA stuff..
psd_list_aperiodic, psd_list_periodic = [], []
for epoch in data_array:
freq, psd_aperiodic, psd_periodic = irasa(
irasa_spectrum = irasa(
epoch,
fs=fs,
band=band,
filter_settings=(data.info['highpass'], data.info['lowpass']),
hset_info=hset_info,
psd_kwargs=kwargs_psd,
)
psd_list_aperiodic.append(psd_aperiodic)
psd_list_periodic.append(psd_periodic)

psd_aperiodic = np.array(psd_list_aperiodic)
psd_periodic = np.array(psd_list_periodic)

if as_array is True:
return psd_aperiodic, psd_periodic, freq
else:
aperiodic = AperiodicEpochsSpectrum(psd_aperiodic, info, freqs=freq, events=events, event_id=event_ids)
periodic = PeriodicEpochsSpectrum(psd_periodic, info, freqs=freq, events=events, event_id=event_ids)

return aperiodic, periodic
psd_list_aperiodic.append(irasa_spectrum.aperiodic.copy())
psd_list_periodic.append(irasa_spectrum.periodic.copy())

psds_aperiodic = np.array(psd_list_aperiodic)
psds_periodic = np.array(psd_list_periodic)

return IrasaEpoched(
periodic=PeriodicEpochsSpectrum(
psds_periodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids
),
aperiodic=AperiodicEpochsSpectrum(
psds_aperiodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids
),
)
34 changes: 23 additions & 11 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import mne
import numpy as np
import pandas as pd
from attrs import define
from mne.time_frequency import EpochsSpectrumArray, SpectrumArray

from pyrasa.utils.aperiodic_utils import compute_slope
from pyrasa.utils.peak_utils import get_peak_params
from pyrasa.utils.types import SlopeFit


class PeriodicSpectrumArray(SpectrumArray):
Expand Down Expand Up @@ -169,7 +171,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -190,7 +192,7 @@ def get_slopes(
"""

df_aps, df_gof = compute_slope(
return compute_slope(
self.get_data(),
self.freqs,
ch_names=self.ch_names,
Expand All @@ -199,8 +201,6 @@ def get_slopes(
fit_bounds=fit_bounds,
)

return df_aps, df_gof


# %%
class PeriodicEpochsSpectrum(EpochsSpectrumArray):
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -420,7 +420,7 @@ def get_slopes(

aps_list, gof_list = [], []
for ix, cur_epoch in enumerate(self.get_data()):
df_aps, df_gof = compute_slope(
slope_fit = compute_slope(
cur_epoch,
self.freqs,
ch_names=self.ch_names,
Expand All @@ -429,9 +429,21 @@ def get_slopes(
fit_bounds=fit_bounds,
)

df_aps['event_id'] = event_dict[events[ix]]
df_gof['event_id'] = event_dict[events[ix]]
aps_list.append(df_aps)
gof_list.append(df_gof)
slope_fit.aperiodic_params['event_id'] = event_dict[events[ix]]
slope_fit.gof['event_id'] = event_dict[events[ix]]
aps_list.append(slope_fit.aperiodic_params.copy())
gof_list.append(slope_fit.gof.copy())

return SlopeFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list))


@define
class IrasaRaw:
periodic: PeriodicSpectrumArray
aperiodic: AperiodicSpectrumArray


return pd.concat(aps_list), pd.concat(gof_list)
@define
class IrasaEpoched:
periodic: PeriodicEpochsSpectrum
aperiodic: AperiodicEpochsSpectrum
Loading

0 comments on commit 6d8fa14

Please sign in to comment.