Skip to content

Commit

Permalink
Merge branch 'develop' into enh/optional-discretize-mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Jan 19, 2024
2 parents a1ba056 + fe90f77 commit 2205fdf
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ straightforward as possible.
- BUG: fin_flutter_analysis doesn't find any fin set [#510](https://github.com/RocketPy-Team/RocketPy/pull/510)
- FIX: EmptyMotor is breaking the Rocket.draw() method [#516](https://github.com/RocketPy-Team/RocketPy/pull/516)
- BUG: 3D trajectory plot not labeling axes [#533](https://github.com/RocketPy-Team/RocketPy/pull/533)
- BUG: Invalid Arguments on Two Dimensional Discretize. [#521](https://github.com/RocketPy-Team/RocketPy/pull/521)

## [v1.1.4] - 2023-12-07

Expand Down
19 changes: 8 additions & 11 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def set_source(self, source):
self : Function
Returns the Function instance.
"""
_ = self._check_user_input(
*_, interpolation, extrapolation = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
Expand Down Expand Up @@ -281,10 +281,10 @@ def source_function(_):
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
self.set_extrapolation(extrapolation)
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
self.set_interpolation(interpolation)
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
Expand Down Expand Up @@ -581,11 +581,9 @@ def set_discrete(
# Create nodes to evaluate function
xs = np.linspace(lower[0], upper[0], sam[0])
ys = np.linspace(lower[1], upper[1], sam[1])
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size)
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(func.get_value(mesh))
zs = np.array(func.get_value(xs, ys))
func.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
func.__interpolation__ = "shepard"
func.__extrapolation__ = "natural"
Expand Down Expand Up @@ -698,11 +696,8 @@ def set_discrete_based_on_model(
# Create nodes to evaluate function
xs = model_function.source[:, 0]
ys = model_function.source[:, 1]
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(func.get_value(mesh))
zs = np.array(func.get_value(xs, ys))
func.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
else:
raise ValueError(

Check warning on line 703 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L703

Added line #L703 was not covered by tests
Expand Down Expand Up @@ -2949,6 +2944,8 @@ def _check_user_input(

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, Function):
source = source.get_source()
if isinstance(source, (list, np.ndarray, str, Path)):
# Deal with csv or txt
if isinstance(source, (str, Path)):
Expand Down
49 changes: 49 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,55 @@ def test_multivariable_function_plot(mock_show):
assert func.plot() == None


def test_set_discrete_2d():
"""Tests the set_discrete method of the Function for
two dimensional domains.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete([-5, -7], [8, 10], [50, 100])

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (50 * 100, 3)
assert np.isclose(discretized_func.source[0, 0], -5)
assert np.isclose(discretized_func.source[0, 1], -7)
assert np.isclose(discretized_func.source[-1, 0], 8)
assert np.isclose(discretized_func.source[-1, 1], 10)


def test_set_discrete_2d_simplified():
"""Tests the set_discrete method of the Function for
two dimensional domains with simplified inputs.
"""
source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
discretized_func = func.set_discrete(-1, 1, 10)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (100, 3)
assert np.isclose(discretized_func.source[0, 0], -1)
assert np.isclose(discretized_func.source[0, 1], -1)
assert np.isclose(discretized_func.source[-1, 0], 1)
assert np.isclose(discretized_func.source[-1, 1], 1)


def test_set_discrete_based_on_2d_model(func_2d_from_csv):
"""Tests the set_discrete_based_on_model method with a 2d model
Function.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete_based_on_model(func_2d_from_csv)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert np.array_equal(
discretized_func.source[:, :2], func_2d_from_csv.source[:, :2]
)
assert discretized_func.__interpolation__ == func_2d_from_csv.__interpolation__
assert discretized_func.__extrapolation__ == func_2d_from_csv.__extrapolation__


@pytest.mark.parametrize(
"x,y,z_expected",
[
Expand Down

0 comments on commit 2205fdf

Please sign in to comment.