diff --git a/tests/test_basis.py b/tests/test_basis.py index 2e66a6cb..2d7e83f3 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -7,6 +7,7 @@ import utils_testing import nemos.basis as basis +from contextlib import nullcontext as does_not_raise # automatic define user accessible basis and check the methods @@ -153,16 +154,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ " - "inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -204,15 +203,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -313,15 +310,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -363,15 +359,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -470,15 +464,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -520,15 +513,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -633,15 +624,14 @@ def test_samples_range_matches_evaluate_requirements(self, sample_range: tuple): def test_number_of_required_inputs_evaluate(self, n_input): """Tests whether the evaluate method correctly processes the number of required inputs.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 2, 3, 4, 5, 6, 10, 11, 100]) @@ -682,15 +672,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """Tests whether the evaluate_on_grid method correctly processes the Input dimensionality.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize( @@ -839,15 +827,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -893,14 +880,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1032,15 +1018,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -1086,14 +1071,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1254,18 +1238,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1341,18 +1320,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1473,18 +1446,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1560,18 +1528,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize("basis_a", [basis.MSplineBasis])