Skip to content

Commit

Permalink
Merge pull request #521 from RocketPy-Team/bug/function-2d-discretize
Browse files Browse the repository at this point in the history
BUG: Invalid Arguments on Two Dimensional Discretize (HOTFIX).
  • Loading branch information
Gui-FernandesBR authored Jan 18, 2024
2 parents fa3d9a7 + 223d598 commit 4127313
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ straightforward as possible.

### Fixed

-
- BUG: Invalid Arguments on Two Dimensional Discretize. [#521](https://github.com/RocketPy-Team/RocketPy/pull/521)

## [v1.1.4] - 2023-12-07

Expand Down
21 changes: 9 additions & 12 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def set_source(self, source):
self : Function
Returns the Function instance.
"""
_ = self._check_user_input(
*_, interpolation, extrapolation = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
Expand Down Expand Up @@ -277,10 +277,10 @@ def source_function(_):
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
self.set_extrapolation(extrapolation)
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
self.set_interpolation(interpolation)
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
Expand Down Expand Up @@ -560,14 +560,12 @@ def set_discrete(
# Create nodes to evaluate function
xs = np.linspace(lower[0], upper[0], sam[0])
ys = np.linspace(lower[1], upper[1], sam[1])
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size)
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(self.get_value(mesh))
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
zs = np.array(self.get_value(xs, ys))
self.__interpolation__ = "shepard"
self.__extrapolation__ = "natural"
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
return self

def set_discrete_based_on_model(
Expand Down Expand Up @@ -664,11 +662,8 @@ def set_discrete_based_on_model(
# Create nodes to evaluate function
xs = model_function.source[:, 0]
ys = model_function.source[:, 1]
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(self.get_value(mesh))
zs = np.array(self.get_value(xs, ys))
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())

interp = (
Expand Down Expand Up @@ -2860,6 +2855,8 @@ def _check_user_input(

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, Function):
source = source.get_source()
if isinstance(source, (list, np.ndarray, str, Path)):
# Deal with csv or txt
if isinstance(source, (str, Path)):
Expand Down
49 changes: 49 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,55 @@ def test_multivariable_function_plot(mock_show):
assert func.plot() == None


def test_set_discrete_2d():
"""Tests the set_discrete method of the Function for
two dimensional domains.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete([-5, -7], [8, 10], [50, 100])

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (50 * 100, 3)
assert np.isclose(discretized_func.source[0, 0], -5)
assert np.isclose(discretized_func.source[0, 1], -7)
assert np.isclose(discretized_func.source[-1, 0], 8)
assert np.isclose(discretized_func.source[-1, 1], 10)


def test_set_discrete_2d_simplified():
"""Tests the set_discrete method of the Function for
two dimensional domains with simplified inputs.
"""
source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
discretized_func = func.set_discrete(-1, 1, 10)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (100, 3)
assert np.isclose(discretized_func.source[0, 0], -1)
assert np.isclose(discretized_func.source[0, 1], -1)
assert np.isclose(discretized_func.source[-1, 0], 1)
assert np.isclose(discretized_func.source[-1, 1], 1)


def test_set_discrete_based_on_2d_model(func_2d_from_csv):
"""Tests the set_discrete_based_on_model method with a 2d model
Function.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete_based_on_model(func_2d_from_csv)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert np.array_equal(
discretized_func.source[:, :2], func_2d_from_csv.source[:, :2]
)
assert discretized_func.__interpolation__ == func_2d_from_csv.__interpolation__
assert discretized_func.__extrapolation__ == func_2d_from_csv.__extrapolation__


@pytest.mark.parametrize(
"x,y,z_expected",
[
Expand Down

0 comments on commit 4127313

Please sign in to comment.