Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

starting attrs integration #48

Merged
merged 6 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
]
keywords=['spectral parametrization', 'oscillations', 'power spectra', '1/f']
requires-python = ">= 3.11"
dependencies = ["numpy", "pandas", "scipy>=1.12"]
dependencies = ["numpy", "pandas", "scipy>=1.12", "attrs"]

[project.optional-dependencies]
mne = ['mne']
Expand Down
79 changes: 25 additions & 54 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,23 @@
import fractions
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np
import scipy.signal as dsp

from pyrasa.utils.irasa_spectrum import IrasaSpectrum
from pyrasa.utils.irasa_tf_spectrum import IrasaTfSpectrum

# from scipy.stats.mstats import gmean
from pyrasa.utils.irasa_utils import (
_check_irasa_settings,
_compute_psd_welch,
_compute_sgramm,
_crop_data, # _find_nearest, _gen_time_from_sft, _get_windows,
_crop_data,
_gen_irasa,
)
from pyrasa.utils.types import IrasaFun

if TYPE_CHECKING:
from pyrasa.utils.input_classes import IrasaSprintKwargsTyped


# TODO: Port to Cython
def _gen_irasa(
data: np.ndarray,
orig_spectrum: np.ndarray,
fs: int,
irasa_fun: IrasaFun,
hset: np.ndarray,
time: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
This function is implementing the IRASA algorithm using a custom function to
compute a power/cross-spectral density and returns an "original", "periodic" and "aperiodic spectrum".
This implementation of the IRASA algorithm is based on the yasa.irasa function in (Vallat & Walker, 2021).

[1] Vallat, Raphael, and Matthew P. Walker. “An open-source,
high-performance tool for automated sleep staging.”
Elife 10 (2021). doi: https://doi.org/10.7554/eLife.70092
"""

spectra = np.zeros((len(hset), *orig_spectrum.shape))
for i, h in enumerate(hset):
rat = fractions.Fraction(str(h))
up, down = rat.numerator, rat.denominator

# Much faster than FFT-based resampling
data_up = dsp.resample_poly(data, up, down, axis=-1)
data_down = dsp.resample_poly(data, down, up, axis=-1)

# Calculate an up/downsampled version of the PSD using same params as original
spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up')
spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down')

# geometric mean between up and downsampled
# be aware of the input dimensions
if spectra.ndim == 2: # noqa PLR2004
spectra[i, :] = np.sqrt(spectrum_up * spectrum_dw)
if spectra.ndim == 3: # noqa PLR2004
spectra[i, :, :] = np.sqrt(spectrum_up * spectrum_dw)

aperiodic_spectrum = np.median(spectra, axis=0)
periodic_spectrum = orig_spectrum - aperiodic_spectrum
return orig_spectrum, aperiodic_spectrum, periodic_spectrum
from pyrasa.utils.types import IrasaSprintKwargsTyped


# %% irasa
Expand All @@ -68,6 +26,7 @@ def irasa(
fs: int,
band: tuple[float, float],
psd_kwargs: dict,
ch_names: np.ndarray | None = None,
win_func: Callable = dsp.windows.hann,
win_func_kwargs: dict | None = None,
dpss_settings_time_bandwidth: float = 2.0,
Expand All @@ -76,7 +35,7 @@ def irasa(
filter_settings: tuple[float | None, float | None] = (None, None),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
hset_accuracy: int = 4,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaSpectrum:
"""
This function can be used to generate aperiodic and periodic power spectra from a time series
using the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -181,15 +140,18 @@ def _local_irasa_fun(
data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset
)

freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1)
freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1)

return freq, psd_aperiodic, psd_periodic
return IrasaSpectrum(
freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic, ch_names=ch_names
)


# irasa sprint
def irasa_sprint( # noqa PLR0915 C901
data: np.ndarray,
fs: int,
ch_names: np.ndarray | None = None,
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
win_duration: float = 0.4,
Expand All @@ -202,7 +164,7 @@ def irasa_sprint( # noqa PLR0915 C901
filter_settings: tuple[float | None, float | None] = (None, None),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
hset_accuracy: int = 4,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaTfSpectrum:
"""

This function can be used to seperate aperiodic from periodic power spectra
Expand Down Expand Up @@ -328,10 +290,19 @@ def _local_irasa_fun(
)

# NOTE: we need to transpose the data as crop_data extracts stuff from the last axis
freq, sgramm_aperiodic, sgramm_periodic = _crop_data(band, freq, sgramm_aperiodic, sgramm_periodic, axis=0)
freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data(
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0
)

# adjust time info (i.e. cut the padded stuff)
tmax = data.shape[1] / fs
t_mask = np.logical_and(time >= 0, time < tmax)

return sgramm_aperiodic[:, t_mask], sgramm_periodic[:, t_mask], freq, time[t_mask]
return IrasaTfSpectrum(
freqs=freq,
time=time[t_mask],
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask],
aperiodic=sgramm_aperiodic[:, t_mask],
ch_names=ch_names,
)
51 changes: 24 additions & 27 deletions pyrasa/irasa_mne/irasa_mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pyrasa.irasa_mne.mne_objs import (
AperiodicEpochsSpectrum,
AperiodicSpectrumArray,
IrasaEpoched,
IrasaRaw,
PeriodicEpochsSpectrum,
PeriodicSpectrumArray,
)
Expand All @@ -16,8 +18,7 @@ def irasa_raw(
duration: float | None = None,
overlap: float | int = 50,
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
as_array: bool = False,
) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaRaw:
"""
This function can be used to seperate aperiodic from periodic power spectra using
the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -87,7 +88,7 @@ def irasa_raw(
'noverlap': int(fs * duration * overlap),
}

freq, psd_aperiodic, psd_periodic = irasa(
irasa_spectrum = irasa(
data_array,
fs=fs,
band=band,
Expand All @@ -96,22 +97,17 @@ def irasa_raw(
psd_kwargs=kwargs_psd,
)

if as_array is True:
return psd_aperiodic, psd_periodic, freq

else:
aperiodic = AperiodicSpectrumArray(psd_aperiodic, info, freqs=freq)
periodic = PeriodicSpectrumArray(psd_periodic, info, freqs=freq)

return aperiodic, periodic
return IrasaRaw(
periodic=PeriodicSpectrumArray(irasa_spectrum.periodic, info, freqs=irasa_spectrum.freqs),
aperiodic=AperiodicSpectrumArray(irasa_spectrum.aperiodic, info, freqs=irasa_spectrum.freqs),
)


def irasa_epochs(
data: mne.Epochs,
band: tuple[float, float] = (1.0, 100.0),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
as_array: bool = False,
) -> tuple[AperiodicSpectrumArray, PeriodicSpectrumArray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> IrasaEpoched:
"""
This function can be used to seperate aperiodic from periodic power spectra
using the IRASA algorithm (Wen & Liu, 2016).
Expand Down Expand Up @@ -177,24 +173,25 @@ def irasa_epochs(
# Do the actual IRASA stuff..
psd_list_aperiodic, psd_list_periodic = [], []
for epoch in data_array:
freq, psd_aperiodic, psd_periodic = irasa(
irasa_spectrum = irasa(
epoch,
fs=fs,
band=band,
filter_settings=(data.info['highpass'], data.info['lowpass']),
hset_info=hset_info,
psd_kwargs=kwargs_psd,
)
psd_list_aperiodic.append(psd_aperiodic)
psd_list_periodic.append(psd_periodic)

psd_aperiodic = np.array(psd_list_aperiodic)
psd_periodic = np.array(psd_list_periodic)

if as_array is True:
return psd_aperiodic, psd_periodic, freq
else:
aperiodic = AperiodicEpochsSpectrum(psd_aperiodic, info, freqs=freq, events=events, event_id=event_ids)
periodic = PeriodicEpochsSpectrum(psd_periodic, info, freqs=freq, events=events, event_id=event_ids)

return aperiodic, periodic
psd_list_aperiodic.append(irasa_spectrum.aperiodic.copy())
psd_list_periodic.append(irasa_spectrum.periodic.copy())

psds_aperiodic = np.array(psd_list_aperiodic)
psds_periodic = np.array(psd_list_periodic)

return IrasaEpoched(
periodic=PeriodicEpochsSpectrum(
psds_periodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids
),
aperiodic=AperiodicEpochsSpectrum(
psds_aperiodic, info, freqs=irasa_spectrum.freqs, events=events, event_id=event_ids
),
)
34 changes: 23 additions & 11 deletions pyrasa/irasa_mne/mne_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import mne
import numpy as np
import pandas as pd
from attrs import define
from mne.time_frequency import EpochsSpectrumArray, SpectrumArray

from pyrasa.utils.aperiodic_utils import compute_slope
from pyrasa.utils.peak_utils import get_peak_params
from pyrasa.utils.types import SlopeFit


class PeriodicSpectrumArray(SpectrumArray):
Expand Down Expand Up @@ -169,7 +171,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -190,7 +192,7 @@ def get_slopes(

"""

df_aps, df_gof = compute_slope(
return compute_slope(
self.get_data(),
self.freqs,
ch_names=self.ch_names,
Expand All @@ -199,8 +201,6 @@ def get_slopes(
fit_bounds=fit_bounds,
)

return df_aps, df_gof


# %%
class PeriodicEpochsSpectrum(EpochsSpectrumArray):
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(

def get_slopes(
self: SpectrumArray, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
) -> SlopeFit:
"""
This method can be used to extract aperiodic parameters from the aperiodic spectrum extracted from IRASA.
The algorithm works by applying one of two different curve fit functions and returns the associated parameters,
Expand All @@ -420,7 +420,7 @@ def get_slopes(

aps_list, gof_list = [], []
for ix, cur_epoch in enumerate(self.get_data()):
df_aps, df_gof = compute_slope(
slope_fit = compute_slope(
cur_epoch,
self.freqs,
ch_names=self.ch_names,
Expand All @@ -429,9 +429,21 @@ def get_slopes(
fit_bounds=fit_bounds,
)

df_aps['event_id'] = event_dict[events[ix]]
df_gof['event_id'] = event_dict[events[ix]]
aps_list.append(df_aps)
gof_list.append(df_gof)
slope_fit.aperiodic_params['event_id'] = event_dict[events[ix]]
slope_fit.gof['event_id'] = event_dict[events[ix]]
aps_list.append(slope_fit.aperiodic_params.copy())
gof_list.append(slope_fit.gof.copy())

return SlopeFit(aperiodic_params=pd.concat(aps_list), gof=pd.concat(gof_list))


@define
class IrasaRaw:
periodic: PeriodicSpectrumArray
aperiodic: AperiodicSpectrumArray


return pd.concat(aps_list), pd.concat(gof_list)
@define
class IrasaEpoched:
periodic: PeriodicEpochsSpectrum
aperiodic: AperiodicEpochsSpectrum
Loading