diff --git a/CHANGELOG.md b/CHANGELOG.md index dc5457371..9fe9d9f7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ straightforward as possible. ### Fixed -- +- BUG: Invalid Arguments on Two Dimensional Discretize. [#521](https://github.com/RocketPy-Team/RocketPy/pull/521) ## [v1.1.4] - 2023-12-07 diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 1c02a3da1..288fb9af1 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -189,7 +189,7 @@ def set_source(self, source): self : Function Returns the Function instance. """ - _ = self._check_user_input( + *_, interpolation, extrapolation = self._check_user_input( source, self.__inputs__, self.__outputs__, @@ -277,10 +277,10 @@ def source_function(_): self.source = source # Update extrapolation method if self.__extrapolation__ is None: - self.set_extrapolation() + self.set_extrapolation(extrapolation) # Set default interpolation for point source if it hasn't if self.__interpolation__ is None: - self.set_interpolation() + self.set_interpolation(interpolation) else: # Updates interpolation coefficients self.set_interpolation(self.__interpolation__) @@ -560,14 +560,12 @@ def set_discrete( # Create nodes to evaluate function xs = np.linspace(lower[0], upper[0], sam[0]) ys = np.linspace(lower[1], upper[1], sam[1]) - xs, ys = np.meshgrid(xs, ys) - xs, ys = xs.flatten(), ys.flatten() - mesh = [[xs[i], ys[i]] for i in range(len(xs))] + xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size) # Evaluate function at all mesh nodes and convert it to matrix - zs = np.array(self.get_value(mesh)) - self.set_source(np.concatenate(([xs], [ys], [zs])).transpose()) + zs = np.array(self.get_value(xs, ys)) self.__interpolation__ = "shepard" self.__extrapolation__ = "natural" + self.set_source(np.concatenate(([xs], [ys], [zs])).transpose()) return self def set_discrete_based_on_model( @@ -664,11 +662,8 @@ def set_discrete_based_on_model( # Create nodes to evaluate function xs = model_function.source[:, 0] ys = model_function.source[:, 1] - xs, ys = np.meshgrid(xs, ys) - xs, ys = xs.flatten(), ys.flatten() - mesh = [[xs[i], ys[i]] for i in range(len(xs))] # Evaluate function at all mesh nodes and convert it to matrix - zs = np.array(self.get_value(mesh)) + zs = np.array(self.get_value(xs, ys)) self.set_source(np.concatenate(([xs], [ys], [zs])).transpose()) interp = ( @@ -2860,6 +2855,8 @@ def _check_user_input( # check source for data type # if list or ndarray, check for dimensions, interpolation and extrapolation + if isinstance(source, Function): + source = source.get_source() if isinstance(source, (list, np.ndarray, str, Path)): # Deal with csv or txt if isinstance(source, (str, Path)): diff --git a/tests/test_function.py b/tests/test_function.py index 5727f146e..3befd40fd 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -372,6 +372,55 @@ def test_multivariable_function_plot(mock_show): assert func.plot() == None +def test_set_discrete_2d(): + """Tests the set_discrete method of the Function for + two dimensional domains. + """ + func = Function(lambda x, y: x**2 + y**2) + discretized_func = func.set_discrete([-5, -7], [8, 10], [50, 100]) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (50 * 100, 3) + assert np.isclose(discretized_func.source[0, 0], -5) + assert np.isclose(discretized_func.source[0, 1], -7) + assert np.isclose(discretized_func.source[-1, 0], 8) + assert np.isclose(discretized_func.source[-1, 1], 10) + + +def test_set_discrete_2d_simplified(): + """Tests the set_discrete method of the Function for + two dimensional domains with simplified inputs. + """ + source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + func = Function(source=source, inputs=["x", "y"], outputs=["z"]) + discretized_func = func.set_discrete(-1, 1, 10) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert discretized_func.source.shape == (100, 3) + assert np.isclose(discretized_func.source[0, 0], -1) + assert np.isclose(discretized_func.source[0, 1], -1) + assert np.isclose(discretized_func.source[-1, 0], 1) + assert np.isclose(discretized_func.source[-1, 1], 1) + + +def test_set_discrete_based_on_2d_model(func_2d_from_csv): + """Tests the set_discrete_based_on_model method with a 2d model + Function. + """ + func = Function(lambda x, y: x**2 + y**2) + discretized_func = func.set_discrete_based_on_model(func_2d_from_csv) + + assert isinstance(discretized_func, Function) + assert isinstance(func, Function) + assert np.array_equal( + discretized_func.source[:, :2], func_2d_from_csv.source[:, :2] + ) + assert discretized_func.__interpolation__ == func_2d_from_csv.__interpolation__ + assert discretized_func.__extrapolation__ == func_2d_from_csv.__extrapolation__ + + @pytest.mark.parametrize( "x,y,z_expected", [