Skip to content

Commit

Permalink
streamlining irasa_sprint and irasa v1
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 20, 2024
1 parent d093dbb commit 1c85c7e
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 44 deletions.
39 changes: 19 additions & 20 deletions examples/irasa_sprint.ipynb

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def irasa_sprint( # noqa PLR0915 C901
ch_names: np.ndarray | None = None,
band: tuple[float, float] = (1.0, 100.0),
win_duration: float = 0.4,
hop: int = 10,
overlap_fraction: float = 0.90,
win_func: Callable = dsp.windows.hann,
win_func_kwargs: dict | None = None,
dpss_settings_time_bandwidth: float = 2.0,
Expand Down Expand Up @@ -209,8 +209,8 @@ def irasa_sprint( # noqa PLR0915 C901
The frequency range (lower and upper bounds in Hz) over which to compute the spectra. Default is (1.0, 100.0).
win_duration : float, optional
Duration of the window in seconds used for the short-time Fourier transforms (STFTs). Default is 0.4 seconds.
hop : int, optional
Time increment in signal samples for the sliding window in STFT. Default is 10 samples.
overlap_fraction : int, optional
The overlap between the STFT sliding windows as fraction. Default is .99 of the windows.
win_func : Callable, optional
Window function to be used in computing the time frequency spectrum. Default is `dsp.windows.hann`.
win_func_kwargs : dict | None, optional
Expand Down Expand Up @@ -280,14 +280,16 @@ def irasa_sprint( # noqa PLR0915 C901
hset = np.round(np.arange(*hset_info), hset_accuracy)
hset = [h for h in hset if h % 1 != 0] # filter integers

nfft = int(2 ** np.ceil(np.log2(np.max(hset) * win_duration * fs)))
win_kwargs = {'win_func': win_func, 'win_func_kwargs': win_func_kwargs}
dpss_settings = {
'time_bandwidth': dpss_settings_time_bandwidth,
'low_bias': dpss_settings_low_bias,
'eigenvalue_weighting': dpss_eigenvalue_weighting,
}

nfft = int(2 ** np.ceil(np.log2(np.max(hset) * win_duration * fs)))
hop = int((1 - overlap_fraction) * win_duration * fs)
# hop = int((1 - overlap_fraction) * nfft)
irasa_kwargs: IrasaSprintKwargsTyped = {
'nfft': nfft,
'hop': hop,
Expand Down Expand Up @@ -324,12 +326,13 @@ def _local_irasa_fun(
# adjust time info (i.e. cut the padded stuff)
tmax = data.shape[1] / fs
t_mask = np.logical_and(time >= 0, time < tmax)
freq_mask = freq > (1 / win_duration) # mask rayleigh

return IrasaTfSpectrum(
freqs=freq,
freqs=freq[freq_mask],
time=time[t_mask],
raw_spectrum=sgramm,
periodic=sgramm_periodic[:, t_mask],
aperiodic=sgramm_aperiodic[:, t_mask],
periodic=sgramm_periodic[:, t_mask][freq_mask, :],
aperiodic=sgramm_aperiodic[:, t_mask][freq_mask, :],
ch_names=ch_names,
)
8 changes: 6 additions & 2 deletions pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd
mse = np.mean(residuals**2)
n = len(psd)

bic = n * np.log(mse) + k * np.log(n)
aic = n * np.log(mse) + 2 * k
loglik = -n / 2 * (1 + np.log(mse) + np.log(2 * np.pi))
aic = 2 * (k - loglik)
bic = k * np.log(n) - 2 * loglik

# bic = n * np.log(mse) + k * np.log(n)
# aic = n * np.log(mse) + 2 * k

gof = pd.DataFrame({'mse': mse, 'r_squared': 1 - (ss_res / ss_tot), 'BIC': bic, 'AIC': aic}, index=[0])
gof['fit_type'] = fit_type
Expand Down
2 changes: 1 addition & 1 deletion pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def _compute_sgramm( # noqa C901
x: np.ndarray,
fs: int,
nfft: int,
win_duration: float,
hop: int,
win_duration: float,
dpss_settings: dict,
win_kwargs: dict,
h: float = 1.0,
Expand Down
8 changes: 5 additions & 3 deletions simulations/notebooks/test_basic_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@
)

#%%
duration = 5

irasa_out_tf = irasa_sprint(
sig,
fs=fs,
band=(1, 50),
win_duration=4,
hset_info=(1, 3, 0.1),
win_duration=duration,
overlap_fraction=.9,
hset_info=(1, 2, 0.1),
)
#%%
from neurodsp.plts import plot_timefrequency
Expand All @@ -64,7 +67,6 @@
powers=irasa_out_tf.aperiodic)



# %%
f, axes = plt.subplots(ncols=2, figsize=(8, 4))
axes[0].set_title('Periodic')
Expand Down
9 changes: 5 additions & 4 deletions tests/test_compute_slope.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,9 @@ def curve_kwargs(self) -> dict[str, any]:
}

aperiodic_fit = compute_aperiodic_model(np.log10(psd), np.log10(freqs), fit_func=CustomFitFun)

# add a high tolerance
assert pytest.approx(np.abs(aperiodic_fit.aperiodic_params['b'][0]), abs=HIGH_TOLERANCE) == np.abs(exponent)

irasa_spectrum = irasa(fixed_aperiodic_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs})

class CustomFitFun(AbstractFitFun):
log10_aperiodic = True
log10_freq = True
Expand All @@ -134,4 +131,8 @@ def func(self, x: np.ndarray, a: float, b: float) -> np.ndarray:

return y_hat

irasa_spectrum.fit_aperiodic_model(fit_func=CustomFitFun)
irasa_spectrum = irasa(fixed_aperiodic_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs})
aperiodic_fit = irasa_spectrum.fit_aperiodic_model(fit_func=CustomFitFun)

# add a high tolerance
assert pytest.approx(np.abs(aperiodic_fit.aperiodic_params['b'][0]), abs=HIGH_TOLERANCE) == np.abs(exponent)
19 changes: 12 additions & 7 deletions tests/test_irasa_sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyrasa.irasa import irasa_sprint
from pyrasa.utils.peak_utils import get_band_info

from .settings import EXPONENT, FS, MIN_R2_SPRINT
from .settings import EXPONENT, FS, MIN_R2_SPRINT, TOLERANCE

set_random_seed(42)

Expand All @@ -16,10 +16,11 @@
def test_irasa_sprint(ts4sprint, fs, exponent_1, exponent_2):
irasa_tf = irasa_sprint(
ts4sprint[np.newaxis, :],
# hop=25,
# win_duration=1,
win_duration=0.5,
overlap_fraction=0.98,
fs=fs,
band=(0.1, 100),
band=(0.1, 50),
hset_info=(1.05, 4.0, 0.05),
)

# check basic aperiodic detection
Expand All @@ -28,14 +29,18 @@ def test_irasa_sprint(ts4sprint, fs, exponent_1, exponent_2):
# )

assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT
assert np.isclose(np.mean(slope_fit.aperiodic_params.query('time < 7')['Exponent']), np.abs(exponent_1), atol=0.5)
assert np.isclose(np.mean(slope_fit.aperiodic_params.query('time > 7')['Exponent']), np.abs(exponent_2), atol=0.5)
assert np.isclose(
np.mean(slope_fit.aperiodic_params.query('time < 7')['Exponent']), np.abs(exponent_1), atol=TOLERANCE
)
assert np.isclose(
np.mean(slope_fit.aperiodic_params.query('time > 7')['Exponent']), np.abs(exponent_2), atol=TOLERANCE
)

# check basic peak detection
df_peaks = irasa_tf.get_peaks(
cut_spectrum=(1, 40),
smooth=True,
smoothing_window=1,
smoothing_window=3,
min_peak_height=0.01,
peak_width_limits=(0.5, 12),
)
Expand Down

0 comments on commit 1c85c7e

Please sign in to comment.