Skip to content

Commit

Permalink
TST: add tests for multivariable functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Nov 15, 2023
1 parent 0634290 commit 36d91c2
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,12 @@ def source(x):
# Update extrapolation method
if (
self.__extrapolation__ is None
or self.__extrapolation__ == "shepard"
or self.__extrapolation__ == "natural"
):
self.set_extrapolation("shepard")
self.set_extrapolation("natural")
else:
raise ValueError(
"Multidimensional datasets only support shepard extrapolation."
"Multidimensional datasets only support natural extrapolation."
)

# Set default multidimensional interpolation if it hasn't
Expand Down
70 changes: 70 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,73 @@ def test_integral_spline_interpolation(request, func, a, b):
func.integral(a, b, numerical=True),
atol=1e-3,
)


@pytest.mark.parametrize("a", [-1, 0, 1])
@pytest.mark.parametrize("b", [-1, 0, 1])
def test_multivariable_dataset(a, b):
"""Test the Function class with a multivariable dataset."""
# Test plane f(x,y) = x + y
source = [
(-1, -1, -2),
(-1, 0, -1),
(-1, 1, 0),
(0, -1, -1),
(0, 0, 0),
(0, 1, 1),
(1, -1, 0),
(1, 0, 1),
(1, 1, 2),
]
func = Function(source=source, inputs=["x", "y"], outputs=["z"])

# Assert interpolation and extrapolation methods
assert func.get_interpolation_method() == "shepard"
assert func.get_extrapolation_method() == "natural"

# Assert values
assert np.isclose(func(a, b), a + b, atol=1e-6)


@pytest.mark.parametrize("a", [-1, -0.5, 0, 0.5, 1])
@pytest.mark.parametrize("b", [-1, -0.5, 0, 0.5, 1])
def test_multivariable_function(a, b):
"""Test the Function class with a multivariable function."""
# Test plane f(x,y) = sin(x + y)
source = lambda x, y: np.sin(x + y)
func = Function(source=source, inputs=["x", "y"], outputs=["z"])

# Assert values
assert np.isclose(func(a, b), np.sin(a + b), atol=1e-6)


@patch("matplotlib.pyplot.show")
def test_multivariable_dataset_plot(mock_show):
"""Test the plot method of the Function class with a multivariable dataset."""
# Test plane f(x,y) = x - y
source = [
(-1, -1, -1),
(-1, 0, -1),
(-1, 1, -2),
(0, 1, 1),
(0, 0, 0),
(0, 1, -1),
(1, -1, 2),
(1, 0, 1),
(1, 1, 0),
]
func = Function(source=source, inputs=["x", "y"], outputs=["z"])

# Assert plot
assert func.plot() == None


@patch("matplotlib.pyplot.show")
def test_multivariable_function_plot(mock_show):
"""Test the plot method of the Function class with a multivariable function."""
# Test plane f(x,y) = sin(x + y)
source = lambda x, y: np.sin(x * y)
func = Function(source=source, inputs=["x", "y"], outputs=["z"])

# Assert plot
assert func.plot() == None

0 comments on commit 36d91c2

Please sign in to comment.