From 06342900461dd3f96a5b09740a687f1e5e11ea8d Mon Sep 17 00:00:00 2001 From: Pedro Bressan Date: Wed, 15 Nov 2023 14:37:37 -0300 Subject: [PATCH 1/2] BUG: fix extrapolation of multivariable functions. --- rocketpy/mathutils/function.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 1fd878b36..4a879bf1f 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -231,8 +231,29 @@ def source(x): # Finally set data source as source self.source = source - if self.__interpolation__ is None: + + # Update extrapolation method + if ( + self.__extrapolation__ is None + or self.__extrapolation__ == "shepard" + ): + self.set_extrapolation("shepard") + else: + raise ValueError( + "Multidimensional datasets only support shepard extrapolation." + ) + + # Set default multidimensional interpolation if it hasn't + if ( + self.__interpolation__ is None + or self.__interpolation__ == "shepard" + ): self.set_interpolation("shepard") + else: + raise ValueError( + "Multidimensional datasets only support shepard interpolation." + ) + # Return self return self From 36d91c2d9af25b866db3eb9f43b559e65064535f Mon Sep 17 00:00:00 2001 From: Pedro Bressan Date: Wed, 15 Nov 2023 15:02:30 -0300 Subject: [PATCH 2/2] TST: add tests for multivariable functions. --- rocketpy/mathutils/function.py | 6 +-- tests/test_function.py | 70 ++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 4a879bf1f..27fd8b717 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -235,12 +235,12 @@ def source(x): # Update extrapolation method if ( self.__extrapolation__ is None - or self.__extrapolation__ == "shepard" + or self.__extrapolation__ == "natural" ): - self.set_extrapolation("shepard") + self.set_extrapolation("natural") else: raise ValueError( - "Multidimensional datasets only support shepard extrapolation." + "Multidimensional datasets only support natural extrapolation." ) # Set default multidimensional interpolation if it hasn't diff --git a/tests/test_function.py b/tests/test_function.py index 4cbac9c33..82d38c1e5 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -173,3 +173,73 @@ def test_integral_spline_interpolation(request, func, a, b): func.integral(a, b, numerical=True), atol=1e-3, ) + + +@pytest.mark.parametrize("a", [-1, 0, 1]) +@pytest.mark.parametrize("b", [-1, 0, 1]) +def test_multivariable_dataset(a, b): + """Test the Function class with a multivariable dataset.""" + # Test plane f(x,y) = x + y + source = [ + (-1, -1, -2), + (-1, 0, -1), + (-1, 1, 0), + (0, -1, -1), + (0, 0, 0), + (0, 1, 1), + (1, -1, 0), + (1, 0, 1), + (1, 1, 2), + ] + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert interpolation and extrapolation methods + assert func.get_interpolation_method() == "shepard" + assert func.get_extrapolation_method() == "natural" + + # Assert values + assert np.isclose(func(a, b), a + b, atol=1e-6) + + +@pytest.mark.parametrize("a", [-1, -0.5, 0, 0.5, 1]) +@pytest.mark.parametrize("b", [-1, -0.5, 0, 0.5, 1]) +def test_multivariable_function(a, b): + """Test the Function class with a multivariable function.""" + # Test plane f(x,y) = sin(x + y) + source = lambda x, y: np.sin(x + y) + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert values + assert np.isclose(func(a, b), np.sin(a + b), atol=1e-6) + + +@patch("matplotlib.pyplot.show") +def test_multivariable_dataset_plot(mock_show): + """Test the plot method of the Function class with a multivariable dataset.""" + # Test plane f(x,y) = x - y + source = [ + (-1, -1, -1), + (-1, 0, -1), + (-1, 1, -2), + (0, 1, 1), + (0, 0, 0), + (0, 1, -1), + (1, -1, 2), + (1, 0, 1), + (1, 1, 0), + ] + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert plot + assert func.plot() == None + + +@patch("matplotlib.pyplot.show") +def test_multivariable_function_plot(mock_show): + """Test the plot method of the Function class with a multivariable function.""" + # Test plane f(x,y) = sin(x + y) + source = lambda x, y: np.sin(x * y) + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + + # Assert plot + assert func.plot() == None