From e1f2fdf4911a0de5b1c473f295bc02b7dc76f6c0 Mon Sep 17 00:00:00 2001 From: Fabi Date: Thu, 8 Aug 2024 22:52:07 +0200 Subject: [PATCH] added tests and fixed issues --- examples/irasa_sprint.ipynb | 21 ++++++++++++++------- tests/conftest.py | 16 ++++++++-------- tests/test_irasa_knee.py | 12 ++++++------ tests/test_irasa_sprint.py | 33 +++++++++++++++++++++------------ 4 files changed, 49 insertions(+), 33 deletions(-) diff --git a/examples/irasa_sprint.ipynb b/examples/irasa_sprint.ipynb index ec4934f..7905c87 100644 --- a/examples/irasa_sprint.ipynb +++ b/examples/irasa_sprint.ipynb @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -155,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -198,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -269,6 +269,13 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/tests/conftest.py b/tests/conftest.py index d1f4924..b6cc310 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,9 +28,10 @@ def knee_aperiodic_signal(exponent, fs, knee_freq): @pytest.fixture(scope='session') -def load_knee_aperiodic_signal(exponent, fs, knee_freq): +def load_knee_aperiodic_signal(exponent, fs, knee): # % generate and save knee - knee_sim = sim_knee(n_seconds=N_SECONDS, fs=fs, exponent1=0, exponent2=exponent, knee=knee_freq) + # knee = knee ** np.abs(exponent) + knee_sim = sim_knee(n_seconds=N_SECONDS, fs=fs, exponent1=0, exponent2=exponent, knee=knee) yield knee_sim # base_dir = 'tests/test_data/knee_data/' # yield np.load( @@ -39,8 +40,8 @@ def load_knee_aperiodic_signal(exponent, fs, knee_freq): @pytest.fixture(scope='session') -def load_knee_cmb_signal(exponent, fs, knee_freq, osc_freq): - knee = knee_freq ** np.abs(exponent) +def load_knee_cmb_signal(exponent, fs, knee, osc_freq): + # knee = knee ** np.abs(exponent) components = { 'sim_knee': {'exponent1': 0, 'exponent2': exponent, 'knee': knee}, 'sim_oscillation': {'freq': osc_freq}, @@ -60,15 +61,14 @@ def oscillation(osc_freq, fs): @pytest.fixture(scope='session') -def ts4sprint(): - fs = 500 +def ts4sprint(fs, exponent_1, exponent_2): alpha = sim_oscillation(n_seconds=0.5, fs=fs, freq=10) no_alpha = np.zeros(len(alpha)) beta = sim_oscillation(n_seconds=0.5, fs=fs, freq=25) no_beta = np.zeros(len(beta)) - exp_1 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-1) - exp_2 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=-2) + exp_1 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=exponent_1) + exp_2 = sim_powerlaw(n_seconds=2.5, fs=fs, exponent=exponent_2) # %% alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha]) diff --git a/tests/test_irasa_knee.py b/tests/test_irasa_knee.py index ea83734..c858c87 100644 --- a/tests/test_irasa_knee.py +++ b/tests/test_irasa_knee.py @@ -11,9 +11,9 @@ # knee model -@pytest.mark.parametrize('exponent, knee_freq', EXP_KNEE_COMBO, scope='session') +@pytest.mark.parametrize('exponent, knee', EXP_KNEE_COMBO, scope='session') @pytest.mark.parametrize('fs', FS, scope='session') -def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee_freq): +def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee): f_range = [0.1, 100] irasa_out = irasa(load_knee_aperiodic_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) # test the shape of the output @@ -33,7 +33,7 @@ def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee_freq knee_hat = slope_fit_k.aperiodic_params['Knee'][0] ** ( 1 / (2 * slope_fit_k.aperiodic_params['Exponent_1'][0] + slope_fit_k.aperiodic_params['Exponent_2'][0]) ) - knee_real = knee_freq ** (1 / np.abs(exponent)) + knee_real = knee ** (1 / np.abs(exponent)) assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE)) # test bic/aic -> should be better for knee assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0] @@ -41,10 +41,10 @@ def test_irasa_knee_peakless(load_knee_aperiodic_signal, fs, exponent, knee_freq # knee model -@pytest.mark.parametrize('exponent, knee_freq', EXP_KNEE_COMBO, scope='session') +@pytest.mark.parametrize('exponent, knee', EXP_KNEE_COMBO, scope='session') @pytest.mark.parametrize('fs', FS, scope='session') @pytest.mark.parametrize('osc_freq', OSC_FREQ, scope='session') -def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee_freq, osc_freq): +def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee, osc_freq): f_range = [0.1, 100] irasa_out = irasa(load_knee_cmb_signal, fs, f_range, psd_kwargs={'nperseg': 4 * fs}) # test the shape of the output @@ -64,7 +64,7 @@ def test_irasa_knee_cmb(load_knee_cmb_signal, fs, exponent, knee_freq, osc_freq) knee_hat = slope_fit_k.aperiodic_params['Knee'][0] ** ( 1 / (2 * slope_fit_k.aperiodic_params['Exponent_1'][0] + slope_fit_k.aperiodic_params['Exponent_2'][0]) ) - knee_real = knee_freq ** (1 / np.abs(exponent)) + knee_real = knee ** (1 / np.abs(exponent)) assert bool(np.isclose(knee_hat, knee_real, atol=KNEE_TOLERANCE)) # test bic/aic -> should be better for knee assert slope_fit_k.gof['AIC'][0] < slope_fit_f.gof['AIC'][0] diff --git a/tests/test_irasa_sprint.py b/tests/test_irasa_sprint.py index 1cc0e86..e4ded71 100644 --- a/tests/test_irasa_sprint.py +++ b/tests/test_irasa_sprint.py @@ -5,16 +5,21 @@ from pyrasa.irasa import irasa_sprint from pyrasa.utils.peak_utils import get_band_info -from .settings import MIN_R2_SPRINT, TOLERANCE +from .settings import EXPONENT, FS, MIN_R2_SPRINT set_random_seed(42) -def test_irasa_sprint(ts4sprint): +@pytest.mark.parametrize('fs', FS, scope='session') +@pytest.mark.parametrize('exponent_1', EXPONENT, scope='session') +@pytest.mark.parametrize('exponent_2', EXPONENT, scope='session') +def test_irasa_sprint(ts4sprint, fs, exponent_1, exponent_2): irasa_tf = irasa_sprint( ts4sprint[np.newaxis, :], - fs=500, - band=(1, 100), + # hop=25, + # win_duration=1, + fs=fs, + band=(0.1, 100), freq_res=0.5, ) @@ -24,11 +29,12 @@ def test_irasa_sprint(ts4sprint): # ) assert slope_fit.gof['r_squared'].mean() > MIN_R2_SPRINT - assert np.isclose(slope_fit.aperiodic_params.query('time < 7')['Exponent'].mean(), 1, atol=TOLERANCE) - assert np.isclose(slope_fit.aperiodic_params.query('time > 7')['Exponent'].mean(), 2, atol=TOLERANCE) + assert np.isclose(np.mean(slope_fit.aperiodic_params.query('time < 7')['Exponent']), np.abs(exponent_1), atol=0.5) + assert np.isclose(np.mean(slope_fit.aperiodic_params.query('time > 7')['Exponent']), np.abs(exponent_2), atol=0.5) # check basic peak detection df_peaks = irasa_tf.get_peaks( + cut_spectrum=(1, 40), smooth=True, smoothing_window=1, min_peak_height=0.01, @@ -50,7 +56,7 @@ def test_irasa_sprint(ts4sprint): pass # one missing burst is ok for now - assert np.isclose(n_peaks, 8, atol=1) + assert np.isclose(n_peaks, 7, atol=4) df_beta = get_band_info(df_peaks, freq_range=(20, 30), ch_names=[]) beta_peaks = df_beta.query('pw > 0.10') @@ -66,18 +72,21 @@ def test_irasa_sprint(ts4sprint): pass # one missing burst is ok for now - assert np.isclose(n_peaks, 12, atol=1) + assert np.isclose(n_peaks, 11, atol=4) # test settings -def test_irasa_sprint_settings(ts4sprint): +@pytest.mark.parametrize('fs', [1000], scope='session') +@pytest.mark.parametrize('exponent_1', [-1], scope='session') +@pytest.mark.parametrize('exponent_2', [-2], scope='session') +def test_irasa_sprint_settings(ts4sprint, fs): # test dpss import scipy.signal as dsp irasa_sprint( ts4sprint[np.newaxis, :], - fs=500, - band=(1, 100), + fs=fs, + band=(0.1, 100), win_func=dsp.windows.dpss, freq_res=0.5, ) @@ -86,7 +95,7 @@ def test_irasa_sprint_settings(ts4sprint): with pytest.raises(ValueError): irasa_sprint( ts4sprint[np.newaxis, :], - fs=500, + fs=fs, band=(1, 100), win_func=dsp.windows.dpss, dpss_settings_time_bandwidth=1,