Commit eca9fb9

updates tests so they pass

billbrod committed Dec 11, 2023
1 parent 231c931
Showing 1 changed file with 93 additions and 131 deletions.
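The change swaps the old ValueError-based input-dimensionality assertions for TypeError expectations built with pytest.raises, using contextlib.nullcontext (imported as does_not_raise) so the raising and non-raising cases share a single with block. A minimal, self-contained sketch of that pattern follows; the dummy evaluate function, the test name, and the parametrized cases are illustrative assumptions and are not part of the repository:

from contextlib import nullcontext as does_not_raise

import pytest


def evaluate(x):
    """Stand-in for a basis evaluation method that takes exactly one input array."""
    return [2 * v for v in x]


@pytest.mark.parametrize(
    "n_input, expectation",
    [
        (0, pytest.raises(TypeError, match=r"evaluate\(\) missing 1 required positional argument")),
        (1, does_not_raise()),
        (2, pytest.raises(TypeError, match=r"evaluate\(\) takes 1 positional argument but 2 were given")),
    ],
)
def test_evaluate_input_count(n_input, expectation):
    # nullcontext() is a no-op context manager, so the same `with` block covers
    # both the error case (pytest.raises) and the success case (does_not_raise).
    inputs = [[0.0, 0.5, 1.0]] * n_input
    with expectation:
        evaluate(*inputs)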
224 changes: 93 additions & 131 deletions tests/test_basis.py
@@ -7,6 +7,7 @@
 import utils_testing
 
 import nemos.basis as basis
+from contextlib import nullcontext as does_not_raise
 
 # automatic define user accessible basis and check the methods

@@ -153,16 +154,14 @@ def test_number_of_required_inputs_evaluate(self, n_input):
         Confirms that the evaluate() method correctly handles the number of input samples that are provided.
         """
         basis_obj = self.cls(n_basis_funcs=5)
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ "
-                "inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100])
@@ -204,15 +203,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """
         basis_obj = self.cls(n_basis_funcs=5)
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -313,15 +310,14 @@ def test_number_of_required_inputs_evaluate(self, n_input):
         Confirms that the evaluate() method correctly handles the number of input samples that are provided.
         """
         basis_obj = self.cls(n_basis_funcs=5)
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100])
@@ -363,15 +359,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """
         basis_obj = self.cls(n_basis_funcs=5)
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -470,15 +464,14 @@ def test_number_of_required_inputs_evaluate(self, n_input):
         Confirms that the evaluate() method correctly handles the number of input samples that are provided.
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100])
@@ -520,15 +513,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -633,15 +624,14 @@ def test_samples_range_matches_evaluate_requirements(self, sample_range: tuple):
     def test_number_of_required_inputs_evaluate(self, n_input):
         """Tests whether the evaluate method correctly processes the number of required inputs."""
         basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6))
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 2, 3, 4, 5, 6, 10, 11, 100])
@@ -682,15 +672,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """Tests whether the evaluate_on_grid method correctly processes the Input dimensionality."""
         basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6))
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)

     @pytest.mark.parametrize(
@@ -839,15 +827,14 @@ def test_number_of_required_inputs_evaluate(self, n_input):
         Confirms that the evaluate() method correctly handles the number of input samples that are provided.
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100])
@@ -893,14 +880,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -1032,15 +1018,14 @@ def test_number_of_required_inputs_evaluate(self, n_input):
         Confirms that the evaluate() method correctly handles the number of input samples that are provided.
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
-        raise_exception = n_input != basis_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100])
@@ -1086,14 +1071,13 @@ def test_evaluate_on_grid_input_number(self, n_input):
         """
         basis_obj = self.cls(n_basis_funcs=5, order=3)
         inputs = [10] * n_input
-        raise_exception = n_input != basis_obj._n_input_dimensionality
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        if n_input == 0:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument")
+        elif n_input != basis_obj._n_input_dimensionality:
+            expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",)
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -1254,18 +1238,13 @@ def test_number_of_required_inputs_evaluate(
         basis_a_obj = self.instantiate_basis(n_basis_a, basis_a)
         basis_b_obj = self.instantiate_basis(n_basis_b, basis_b)
         basis_obj = basis_a_obj + basis_b_obj
-        raise_exception = (
-            n_input
-            != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
-        )
+        required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input != required_dim:
+            expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.")
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [11, 20])
@@ -1341,18 +1320,12 @@ def test_evaluate_on_grid_input_number(
         basis_b_obj = self.instantiate_basis(n_basis_b, basis_b)
         basis_obj = basis_a_obj + basis_b_obj
         inputs = [20] * n_input
-        raise_exception = (
-            n_input
-            != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
-        )
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
+        if n_input != required_dim:
+            expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.")
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)


@@ -1473,18 +1446,13 @@ def test_number_of_required_inputs_evaluate(
         basis_a_obj = self.instantiate_basis(n_basis_a, basis_a)
         basis_b_obj = self.instantiate_basis(n_basis_b, basis_b)
         basis_obj = basis_a_obj * basis_b_obj
-        raise_exception = (
-            n_input
-            != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
-        )
+        required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
         inputs = [np.linspace(0, 1, 20)] * n_input
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,",
-            ):
-                basis_obj.evaluate(*inputs)
+        if n_input != required_dim:
+            expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.")
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate(*inputs)

     @pytest.mark.parametrize("sample_size", [11, 20])
@@ -1560,18 +1528,12 @@ def test_evaluate_on_grid_input_number(
         basis_b_obj = self.instantiate_basis(n_basis_b, basis_b)
         basis_obj = basis_a_obj * basis_b_obj
         inputs = [20] * n_input
-        raise_exception = (
-            n_input
-            != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
-        )
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs, "
-                r"[0-9]+ inputs provided instead.",
-            ):
-                basis_obj.evaluate_on_grid(*inputs)
+        required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality
+        if n_input != required_dim:
+            expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.")
+        else:
+            expectation = does_not_raise()
+        with expectation:
+            basis_obj.evaluate_on_grid(*inputs)

     @pytest.mark.parametrize("basis_a", [basis.MSplineBasis])
