From 2de6ef738a324f6d85cd1d2fe2e4a5b84b0e7f39 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci <42842079+mmahsereci@users.noreply.github.com> Date: Tue, 7 May 2024 15:49:55 +0200 Subject: [PATCH] Small bugfix in tests. (#462) --- tests/emukit/quadrature/test_measures.py | 17 +++++++---------- tests/emukit/quadrature/test_warpings.py | 24 ++++++++++++------------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/tests/emukit/quadrature/test_measures.py b/tests/emukit/quadrature/test_measures.py index b4b7dff0..d656de59 100644 --- a/tests/emukit/quadrature/test_measures.py +++ b/tests/emukit/quadrature/test_measures.py @@ -75,27 +75,24 @@ def gauss_measure(): return DataGaussMeasure() -measure_test_list = [ - DataLebesgueMeasure(), - DataLebesgueNormalizedMeasure(), - DataGaussIsoMeasure(), - DataGaussMeasure(), -] +measure_test_list = ["lebesgue_measure", "lebesgue_measure_normalized", "gauss_iso_measure", "gauss_measure"] # === tests shared by all measures start here -@pytest.mark.parametrize("measure", measure_test_list) -def test_measure_gradient_values(measure): +@pytest.mark.parametrize("measure_name", measure_test_list) +def test_measure_gradient_values(measure_name, request): + measure = request.getfixturevalue(measure_name) D, measure, dat_bounds = measure.D, measure.measure, measure.dat_bounds func = lambda x: measure.compute_density(x) dfunc = lambda x: measure.compute_density_gradient(x).T check_grad(func, dfunc, in_shape=(3, D), bounds=dat_bounds) -@pytest.mark.parametrize("measure", measure_test_list) -def test_measure_shapes(measure): +@pytest.mark.parametrize("measure_name", measure_test_list) +def test_measure_shapes(measure_name, request): + measure = request.getfixturevalue(measure_name) D, measure = measure.D, measure.measure # box bounds diff --git a/tests/emukit/quadrature/test_warpings.py b/tests/emukit/quadrature/test_warpings.py index db107df7..ef3feea5 100644 --- a/tests/emukit/quadrature/test_warpings.py +++ b/tests/emukit/quadrature/test_warpings.py @@ -17,21 +17,21 @@ def identity_warping(): @pytest.fixture -def squarerroot_warping(): +def square_root_warping(): offset = 1.0 return SquareRootWarping(offset=offset) @pytest.fixture -def inverted_squarerroot_warping(): +def inverted_square_root_warping(): offset = 1.0 return SquareRootWarping(offset=offset, is_inverted=True) warpings = [ "identity_warping", - "squarerroot_warping", - "inverted_squarerroot_warping", + "square_root_warping", + "inverted_square_root_warping", ] @@ -56,16 +56,16 @@ def test_warping_values(warping_name, request): assert_allclose(warping.inverse_transform(warping.transform(Y)), Y, rtol=RTOL, atol=ATOL) -def test_squarerroot_warping_update_parameters(squarerroot_warping, inverted_squarerroot_warping): +def test_square_root_warping_update_parameters(square_root_warping, inverted_square_root_warping): new_offset = 10.0 - squarerroot_warping.update_parameters(offset=new_offset) - assert squarerroot_warping.offset == new_offset + square_root_warping.update_parameters(offset=new_offset) + assert square_root_warping.offset == new_offset - inverted_squarerroot_warping.update_parameters(offset=new_offset) - assert inverted_squarerroot_warping.offset == new_offset + inverted_square_root_warping.update_parameters(offset=new_offset) + assert inverted_square_root_warping.offset == new_offset -def test_squarerroot_warping_inverted_flag(squarerroot_warping, inverted_squarerroot_warping): - assert not squarerroot_warping.is_inverted - assert inverted_squarerroot_warping.is_inverted +def test_square_root_warping_inverted_flag(square_root_warping, inverted_square_root_warping): + assert not square_root_warping.is_inverted + assert inverted_square_root_warping.is_inverted