From bc0c916187ac75850b4d8339ca0e17dca47fcd6e Mon Sep 17 00:00:00 2001 From: Schmidt Fabian Date: Thu, 1 Aug 2024 13:39:25 +0200 Subject: [PATCH] add fixed channel info --- pyrasa/irasa.py | 7 ++++++- pyrasa/utils/aperiodic_utils.py | 10 +++++----- pyrasa/utils/irasa_spectrum.py | 10 +++++----- pyrasa/utils/irasa_tf_spectrum.py | 5 +++-- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pyrasa/irasa.py b/pyrasa/irasa.py index 29ad408..279a5e2 100644 --- a/pyrasa/irasa.py +++ b/pyrasa/irasa.py @@ -26,6 +26,7 @@ def irasa( fs: int, band: tuple[float, float], psd_kwargs: dict, + ch_names: np.ndarray | None = None, win_func: Callable = dsp.windows.hann, win_func_kwargs: dict | None = None, dpss_settings_time_bandwidth: float = 2.0, @@ -141,13 +142,16 @@ def _local_irasa_fun( freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1) - return IrasaSpectrum(freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic) + return IrasaSpectrum( + freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic, ch_names=ch_names + ) # irasa sprint def irasa_sprint( # noqa PLR0915 C901 data: np.ndarray, fs: int, + ch_names: np.ndarray | None = None, band: tuple[float, float] = (1.0, 100.0), freq_res: float = 0.5, win_duration: float = 0.4, @@ -300,4 +304,5 @@ def _local_irasa_fun( raw_spectrum=sgramm, periodic=sgramm_periodic[:, t_mask], aperiodic=sgramm_aperiodic[:, t_mask], + ch_names=ch_names, ) diff --git a/pyrasa/utils/aperiodic_utils.py b/pyrasa/utils/aperiodic_utils.py index b014e8a..030adf4 100644 --- a/pyrasa/utils/aperiodic_utils.py +++ b/pyrasa/utils/aperiodic_utils.py @@ -136,7 +136,7 @@ def compute_slope( aperiodic_spectrum: np.ndarray, freqs: np.ndarray, fit_func: str, - ch_names: Iterable = (), + ch_names: Iterable | None = None, scale: bool = False, fit_bounds: tuple[float, float] | None = None, ) -> SlopeFit: @@ -179,8 +179,8 @@ def compute_slope( assert freqs.ndim == 1, 'freqs needs to be of shape (freqs,).' assert isinstance( - ch_names, list | tuple | np.ndarray - ), 'Channel names should be of type list, tuple or numpy.ndarray' + ch_names, list | tuple | np.ndarray | None + ), 'Channel names should be of type list, tuple or numpy.ndarray or None' if fit_bounds is not None: fmin, fmax = freqs.min(), freqs.max() @@ -196,7 +196,7 @@ def compute_slope( aperiodic_spectrum = aperiodic_spectrum[:, 1:] # generate channel names if not given - if len(ch_names) == 0: + if ch_names is None: ch_names = np.arange(aperiodic_spectrum.shape[0]) if scale: @@ -235,7 +235,7 @@ def compute_slope_sprint( times: np.ndarray, fit_func: str, scale: bool = False, - ch_names: Iterable = (), + ch_names: Iterable | None = None, fit_bounds: tuple[float, float] | None = None, ) -> SlopeFit: """ diff --git a/pyrasa/utils/irasa_spectrum.py b/pyrasa/utils/irasa_spectrum.py index 304be1b..9c7014b 100644 --- a/pyrasa/utils/irasa_spectrum.py +++ b/pyrasa/utils/irasa_spectrum.py @@ -13,7 +13,7 @@ class IrasaSpectrum: raw_spectrum: np.ndarray aperiodic: np.ndarray periodic: np.ndarray - # ch_names: np.ndarray | None + ch_names: np.ndarray | None def get_slopes( self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None @@ -39,9 +39,9 @@ def get_slopes( """ return compute_slope( - self.aperiodic, - self.freqs, - # ch_names=self.ch_names, + aperiodic_spectrum=self.aperiodic, + freqs=self.freqs, + ch_names=self.ch_names, scale=scale, fit_func=fit_func, fit_bounds=fit_bounds, @@ -83,7 +83,7 @@ def get_peaks( return get_peak_params( self.periodic, self.freqs, - # self.ch_names, + self.ch_names, smoothing_window=smoothing_window, cut_spectrum=cut_spectrum, peak_threshold=peak_threshold, diff --git a/pyrasa/utils/irasa_tf_spectrum.py b/pyrasa/utils/irasa_tf_spectrum.py index f37d800..cb78fea 100644 --- a/pyrasa/utils/irasa_tf_spectrum.py +++ b/pyrasa/utils/irasa_tf_spectrum.py @@ -16,6 +16,7 @@ class IrasaTfSpectrum: raw_spectrum: np.ndarray aperiodic: np.ndarray periodic: np.ndarray + ch_names: np.ndarray | None def get_slopes( self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None @@ -44,7 +45,7 @@ def get_slopes( aperiodic_spectrum=self.aperiodic[np.newaxis, :, :] if self.aperiodic.ndim == min_ndim else self.aperiodic, freqs=self.freqs, times=self.time, - # ch_names=self.ch_names, + ch_names=self.ch_names, scale=scale, fit_func=fit_func, fit_bounds=fit_bounds, @@ -88,7 +89,7 @@ def get_peaks( periodic_spectrum=self.periodic[np.newaxis, :, :] if self.periodic.ndim == min_ndim else self.periodic, freqs=self.freqs, times=self.time, - # self.ch_names, + ch_names=self.ch_names, smooth=smooth, smoothing_window=smoothing_window, cut_spectrum=cut_spectrum,