Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Invalid Arguments on Two Dimensional Discretize (HOTFIX). #521

Merged
merged 5 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ straightforward as possible.

### Fixed

-
- BUG: Invalid Arguments on Two Dimensional Discretize. [#521](https://github.com/RocketPy-Team/RocketPy/pull/521)

## [v1.1.4] - 2023-12-07

Expand Down
21 changes: 9 additions & 12 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def set_source(self, source):
self : Function
Returns the Function instance.
"""
_ = self._check_user_input(
*_, interpolation, extrapolation = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
Expand Down Expand Up @@ -277,10 +277,10 @@ def source_function(_):
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
self.set_extrapolation(extrapolation)
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
self.set_interpolation(interpolation)
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
Expand Down Expand Up @@ -560,14 +560,12 @@ def set_discrete(
# Create nodes to evaluate function
xs = np.linspace(lower[0], upper[0], sam[0])
ys = np.linspace(lower[1], upper[1], sam[1])
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size)
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(self.get_value(mesh))
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
zs = np.array(self.get_value(xs, ys))
self.__interpolation__ = "shepard"
self.__extrapolation__ = "natural"
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
return self

def set_discrete_based_on_model(
Expand Down Expand Up @@ -664,11 +662,8 @@ def set_discrete_based_on_model(
# Create nodes to evaluate function
xs = model_function.source[:, 0]
ys = model_function.source[:, 1]
xs, ys = np.meshgrid(xs, ys)
xs, ys = xs.flatten(), ys.flatten()
mesh = [[xs[i], ys[i]] for i in range(len(xs))]
# Evaluate function at all mesh nodes and convert it to matrix
zs = np.array(self.get_value(mesh))
zs = np.array(self.get_value(xs, ys))
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())

interp = (
Expand Down Expand Up @@ -2860,6 +2855,8 @@ def _check_user_input(

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, Function):
source = source.get_source()
if isinstance(source, (list, np.ndarray, str, Path)):
# Deal with csv or txt
if isinstance(source, (str, Path)):
Expand Down
49 changes: 49 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,55 @@ def test_multivariable_function_plot(mock_show):
assert func.plot() == None


def test_set_discrete_2d():
"""Tests the set_discrete method of the Function for
two dimensional domains.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete([-5, -7], [8, 10], [50, 100])

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (50 * 100, 3)
assert np.isclose(discretized_func.source[0, 0], -5)
assert np.isclose(discretized_func.source[0, 1], -7)
assert np.isclose(discretized_func.source[-1, 0], 8)
assert np.isclose(discretized_func.source[-1, 1], 10)


def test_set_discrete_2d_simplified():
"""Tests the set_discrete method of the Function for
two dimensional domains with simplified inputs.
"""
source = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
discretized_func = func.set_discrete(-1, 1, 10)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert discretized_func.source.shape == (100, 3)
assert np.isclose(discretized_func.source[0, 0], -1)
assert np.isclose(discretized_func.source[0, 1], -1)
assert np.isclose(discretized_func.source[-1, 0], 1)
assert np.isclose(discretized_func.source[-1, 1], 1)


def test_set_discrete_based_on_2d_model(func_2d_from_csv):
"""Tests the set_discrete_based_on_model method with a 2d model
Function.
"""
func = Function(lambda x, y: x**2 + y**2)
discretized_func = func.set_discrete_based_on_model(func_2d_from_csv)

assert isinstance(discretized_func, Function)
assert isinstance(func, Function)
assert np.array_equal(
discretized_func.source[:, :2], func_2d_from_csv.source[:, :2]
)
assert discretized_func.__interpolation__ == func_2d_from_csv.__interpolation__
assert discretized_func.__extrapolation__ == func_2d_from_csv.__extrapolation__


@pytest.mark.parametrize(
"x,y,z_expected",
[
Expand Down
Loading