From 84b3c660cab2b33623c2136b7f7214cf190c93cd Mon Sep 17 00:00:00 2001 From: Fabi Date: Mon, 16 Sep 2024 09:15:56 +0200 Subject: [PATCH] fixed test bug --- examples/irasa_sprint.ipynb | 2 +- tests/conftest.py | 29 +++++++++++++++++++++++++++++ tests/test_irasa_knee.py | 18 ++++++++++++++---- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/examples/irasa_sprint.ipynb b/examples/irasa_sprint.ipynb index bf11504..0cbb2b0 100644 --- a/examples/irasa_sprint.ipynb +++ b/examples/irasa_sprint.ipynb @@ -513,7 +513,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.4" + "version": "3.1.undefined" } }, "nbformat": 4, diff --git a/tests/conftest.py b/tests/conftest.py index b6cc310..5eb63c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,6 +87,35 @@ def ts4sprint(fs, exponent_1, exponent_2): yield sim_ts +@pytest.fixture(scope='session') +def ts4sprint_knee(fs, exponent_1, exponent_2): + alpha = sim_oscillation(n_seconds=0.5, fs=fs, freq=10) + no_alpha = np.zeros(len(alpha)) + beta = sim_oscillation(n_seconds=0.5, fs=fs, freq=25) + no_beta = np.zeros(len(beta)) + + knee1 = 20 ** np.abs(exponent_1) + knee2 = 20 ** np.abs(exponent_2) + exp_1 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_1, knee=knee1) + exp_2 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_2, knee=knee2) + + # %% + alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha]) + betas = np.concatenate([beta, no_beta, beta, no_beta, beta]) + + sim_ts = np.concatenate( + [ + exp_1 + alphas, + exp_1 + alphas + betas, + exp_1 + betas, + exp_2 + alphas, + exp_2 + alphas + betas, + exp_2 + betas, + ] + ) + yield sim_ts + + @pytest.fixture(scope='session') def gen_mne_data_raw(): data_path = sample.data_path() diff --git a/tests/test_irasa_knee.py b/tests/test_irasa_knee.py index d951ca4..e28194a 100644 --- a/tests/test_irasa_knee.py +++ b/tests/test_irasa_knee.py @@ -114,9 +114,9 @@ def test_aperiodic_error(load_knee_cmb_signal, fs, exponent, knee, osc_freq): @pytest.mark.parametrize('fs', [1000], scope='session') @pytest.mark.parametrize('exponent_1', [-0], scope='session') @pytest.mark.parametrize('exponent_2', [-2], scope='session') -def test_aperiodic_error_tf(ts4sprint, fs, exponent, knee, osc_freq): +def test_aperiodic_error_tf(ts4sprint_knee, fs, exponent_1, exponent_2): irasa_out = irasa_sprint( - ts4sprint, + ts4sprint_knee, fs=fs, band=(0.1, 50), overlap_fraction=0.95, @@ -125,7 +125,7 @@ def test_aperiodic_error_tf(ts4sprint, fs, exponent, knee, osc_freq): ) irasa_out_bad = irasa_sprint( - ts4sprint, + ts4sprint_knee, fs=fs, band=(0.1, 50), overlap_fraction=0.95, @@ -133,4 +133,14 @@ def test_aperiodic_error_tf(ts4sprint, fs, exponent, knee, osc_freq): hset_info=(1, 8.0, 0.05), ) - assert np.mean(irasa_out.get_aperiodic_error()) < np.mean(irasa_out_bad.get_aperiodic_error()) + kwargs = { + 'cut_spectrum': (1, 40), + 'smooth': True, + 'smoothing_window': 3, + 'min_peak_height': 0.01, + 'peak_width_limits': (0.5, 12), + } + + assert np.mean(irasa_out.get_aperiodic_error(peak_kwargs=kwargs)) < np.mean( + irasa_out_bad.get_aperiodic_error(peak_kwargs=kwargs) + )