diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 4a879bf1f..27fd8b717 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -235,12 +235,12 @@ def source(x): # Update extrapolation method if ( self.__extrapolation__ is None - or self.__extrapolation__ == "shepard" + or self.__extrapolation__ == "natural" ): - self.set_extrapolation("shepard") + self.set_extrapolation("natural") else: raise ValueError( - "Multidimensional datasets only support shepard extrapolation." + "Multidimensional datasets only support natural extrapolation." ) # Set default multidimensional interpolation if it hasn't diff --git a/tests/test_function.py b/tests/test_function.py index 4cbac9c33..82d38c1e5 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -173,3 +173,73 @@ def test_integral_spline_interpolation(request, func, a, b): func.integral(a, b, numerical=True), atol=1e-3, ) + + +@pytest.mark.parametrize("a", [-1, 0, 1]) +@pytest.mark.parametrize("b", [-1, 0, 1]) +def test_multivariable_dataset(a, b): + """Test the Function class with a multivariable dataset.""" + # Test plane f(x,y) = x + y + source = [ + (-1, -1, -2), + (-1, 0, -1), + (-1, 1, 0), + (0, -1, -1), + (0, 0, 0), + (0, 1, 1), + (1, -1, 0), + (1, 0, 1), + (1, 1, 2), + ] + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert interpolation and extrapolation methods + assert func.get_interpolation_method() == "shepard" + assert func.get_extrapolation_method() == "natural" + + # Assert values + assert np.isclose(func(a, b), a + b, atol=1e-6) + + +@pytest.mark.parametrize("a", [-1, -0.5, 0, 0.5, 1]) +@pytest.mark.parametrize("b", [-1, -0.5, 0, 0.5, 1]) +def test_multivariable_function(a, b): + """Test the Function class with a multivariable function.""" + # Test plane f(x,y) = sin(x + y) + source = lambda x, y: np.sin(x + y) + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert values + assert np.isclose(func(a, b), np.sin(a + b), atol=1e-6) + + +@patch("matplotlib.pyplot.show") +def test_multivariable_dataset_plot(mock_show): + """Test the plot method of the Function class with a multivariable dataset.""" + # Test plane f(x,y) = x - y + source = [ + (-1, -1, -1), + (-1, 0, -1), + (-1, 1, -2), + (0, 1, 1), + (0, 0, 0), + (0, 1, -1), + (1, -1, 2), + (1, 0, 1), + (1, 1, 0), + ] + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert plot + assert func.plot() == None + + +@patch("matplotlib.pyplot.show") +def test_multivariable_function_plot(mock_show): + """Test the plot method of the Function class with a multivariable function.""" + # Test plane f(x,y) = sin(x + y) + source = lambda x, y: np.sin(x * y) + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert plot + assert func.plot() == None