Skip to content

Commit

Permalink
add fixed channel info
Browse files Browse the repository at this point in the history
  • Loading branch information
Schmidt Fabian committed Aug 1, 2024
1 parent 3a44d8f commit bc0c916
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
7 changes: 6 additions & 1 deletion pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def irasa(
fs: int,
band: tuple[float, float],
psd_kwargs: dict,
ch_names: np.ndarray | None = None,
win_func: Callable = dsp.windows.hann,
win_func_kwargs: dict | None = None,
dpss_settings_time_bandwidth: float = 2.0,
Expand Down Expand Up @@ -141,13 +142,16 @@ def _local_irasa_fun(

freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1)

return IrasaSpectrum(freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic)
return IrasaSpectrum(
freqs=freq, raw_spectrum=psd, aperiodic=psd_aperiodic, periodic=psd_periodic, ch_names=ch_names
)


# irasa sprint
def irasa_sprint( # noqa PLR0915 C901
data: np.ndarray,
fs: int,
ch_names: np.ndarray | None = None,
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
win_duration: float = 0.4,
Expand Down Expand Up @@ -300,4 +304,5 @@ def _local_irasa_fun(
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask],
aperiodic=sgramm_aperiodic[:, t_mask],
ch_names=ch_names,
)
10 changes: 5 additions & 5 deletions pyrasa/utils/aperiodic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def compute_slope(
aperiodic_spectrum: np.ndarray,
freqs: np.ndarray,
fit_func: str,
ch_names: Iterable = (),
ch_names: Iterable | None = None,
scale: bool = False,
fit_bounds: tuple[float, float] | None = None,
) -> SlopeFit:
Expand Down Expand Up @@ -179,8 +179,8 @@ def compute_slope(
assert freqs.ndim == 1, 'freqs needs to be of shape (freqs,).'

assert isinstance(
ch_names, list | tuple | np.ndarray
), 'Channel names should be of type list, tuple or numpy.ndarray'
ch_names, list | tuple | np.ndarray | None
), 'Channel names should be of type list, tuple or numpy.ndarray or None'

if fit_bounds is not None:
fmin, fmax = freqs.min(), freqs.max()
Expand All @@ -196,7 +196,7 @@ def compute_slope(
aperiodic_spectrum = aperiodic_spectrum[:, 1:]

# generate channel names if not given
if len(ch_names) == 0:
if ch_names is None:
ch_names = np.arange(aperiodic_spectrum.shape[0])

if scale:
Expand Down Expand Up @@ -235,7 +235,7 @@ def compute_slope_sprint(
times: np.ndarray,
fit_func: str,
scale: bool = False,
ch_names: Iterable = (),
ch_names: Iterable | None = None,
fit_bounds: tuple[float, float] | None = None,
) -> SlopeFit:
"""
Expand Down
10 changes: 5 additions & 5 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class IrasaSpectrum:
raw_spectrum: np.ndarray
aperiodic: np.ndarray
periodic: np.ndarray
# ch_names: np.ndarray | None
ch_names: np.ndarray | None

def get_slopes(
self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
Expand All @@ -39,9 +39,9 @@ def get_slopes(
"""
return compute_slope(
self.aperiodic,
self.freqs,
# ch_names=self.ch_names,
aperiodic_spectrum=self.aperiodic,
freqs=self.freqs,
ch_names=self.ch_names,
scale=scale,
fit_func=fit_func,
fit_bounds=fit_bounds,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_peaks(
return get_peak_params(
self.periodic,
self.freqs,
# self.ch_names,
self.ch_names,
smoothing_window=smoothing_window,
cut_spectrum=cut_spectrum,
peak_threshold=peak_threshold,
Expand Down
5 changes: 3 additions & 2 deletions pyrasa/utils/irasa_tf_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class IrasaTfSpectrum:
raw_spectrum: np.ndarray
aperiodic: np.ndarray
periodic: np.ndarray
ch_names: np.ndarray | None

def get_slopes(
self, fit_func: str = 'fixed', scale: bool = False, fit_bounds: tuple[float, float] | None = None
Expand Down Expand Up @@ -44,7 +45,7 @@ def get_slopes(
aperiodic_spectrum=self.aperiodic[np.newaxis, :, :] if self.aperiodic.ndim == min_ndim else self.aperiodic,
freqs=self.freqs,
times=self.time,
# ch_names=self.ch_names,
ch_names=self.ch_names,
scale=scale,
fit_func=fit_func,
fit_bounds=fit_bounds,
Expand Down Expand Up @@ -88,7 +89,7 @@ def get_peaks(
periodic_spectrum=self.periodic[np.newaxis, :, :] if self.periodic.ndim == min_ndim else self.periodic,
freqs=self.freqs,
times=self.time,
# self.ch_names,
ch_names=self.ch_names,
smooth=smooth,
smoothing_window=smoothing_window,
cut_spectrum=cut_spectrum,
Expand Down

0 comments on commit bc0c916

Please sign in to comment.