Skip to content

Commit

Permalink
Merge pull request #478 from RocketPy-Team/hotfix/csv-2d-function
Browse files Browse the repository at this point in the history
HOTFIX: 2D .CSV Function and missing set_get_value_opt call
  • Loading branch information
MateusStano authored Nov 22, 2023
2 parents 8428fd8 + 2ee958a commit e18ca22
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 25 deletions.
37 changes: 22 additions & 15 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,6 @@ def source_function(_):

# Finally set data source as source
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
# Do things if function is multivariate
else:
self.x_array = source[:, 0]
Expand All @@ -251,6 +242,15 @@ def source_function(_):

# Finally set data source as source
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
return self

@cached_property
Expand Down Expand Up @@ -329,7 +329,7 @@ def set_extrapolation(self, method="constant"):
def set_get_value_opt(self):
"""Crates a method that evaluates interpolations rather quickly
when compared to other options available, such as just calling
the object instance or calling ``Function.get_value directly``. See
the object instance or calling ``Function.get_value`` directly. See
``Function.get_value_opt`` for documentation.
Returns
Expand Down Expand Up @@ -2785,7 +2785,8 @@ def _check_user_input(
dimensions of inputs and outputs. If the outputs list has more than
one element.
TypeError
If the source is not a list, np.ndarray, or Function object.
If the source is not a list, np.ndarray, Function object, str or
Path.
Warning
If inputs or outputs do not match for a Function source, or if
defaults are used for inputs, interpolation,and extrapolation for a
Expand Down Expand Up @@ -2825,10 +2826,16 @@ def _check_user_input(

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, (list, np.ndarray)):
# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)
if isinstance(source, (list, np.ndarray, str, Path)):
# Deal with csv or txt
if isinstance(source, (str, Path)):
# Convert to numpy array
source = np.loadtxt(source, delimiter=",", dtype=float)

else:
# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)

# check dimensions
source_dim = source.shape[1]
Expand Down
21 changes: 17 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,9 +1112,22 @@ def func_from_csv():
"""
func = Function(
source="tests/fixtures/airfoils/e473-10e6-degrees.csv",
inputs=["Scalar"],
outputs=["Scalar"],
interpolation="linear",
extrapolation="natural",
)
return func


@pytest.fixture
def func_2d_from_csv():
"""Create a 2d function based on a csv file.
Returns
-------
rocketpy.Function
A function based on a csv file.
"""
# Do not define any of the optional parameters so that the tests can check
# if the defaults are being used correctly.
func = Function(
source="tests/fixtures/function/2d.csv",
)
return func
133 changes: 133 additions & 0 deletions tests/fixtures/function/2d.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
0.0, 0.0, 0.000
0.0, 0.1, 0.000
0.0, 0.2, 0.000
0.0, 0.3, 0.000
0.0, 0.4, 0.000
0.0, 0.5, 0.000
0.0, 0.6, 0.000
0.0, 0.7, 0.000
0.0, 0.8, 0.000
0.0, 0.9, 0.000
0.0, 1.0, 0.000
0.0, 1.1, 0.000
0.1, 0.0, 0.000
0.2, 0.0, 0.000
0.3, 0.0, 0.000
0.4, 0.0, 0.000
0.5, 0.0, 0.000
0.6, 0.0, 0.000
0.7, 0.0, 0.000
0.8, 0.0, 0.000
0.9, 0.0, 0.000
1.0, 0.0, 0.000
1.1, 0.0, 0.000
0.1, 0.1, 0.000
0.1, 0.2, 0.000
0.1, 0.3, 0.045
0.1, 0.4, 0.022
0.1, 0.5, 0.026
0.1, 0.6, 0.076
0.1, 0.7, 0.051
0.1, 0.8, 0.058
0.1, 0.9, 0.022
0.1, 1.0, 0.000
0.1, 1.1, 0.000
0.2, 0.1, 0.022
0.2, 0.2, 0.062
0.2, 0.3, 0.093
0.2, 0.4, 0.076
0.2, 0.5, 0.070
0.2, 0.6, 0.129
0.2, 0.7, 0.102
0.2, 0.8, 0.115
0.2, 0.9, 0.084
0.2, 1.0, 0.018
0.2, 1.1, 0.000
0.3, 0.1, 0.056
0.3, 0.2, 0.106
0.3, 0.3, 0.147
0.3, 0.4, 0.139
0.3, 0.5, 0.139
0.3, 0.6, 0.183
0.3, 0.7, 0.169
0.3, 0.8, 0.183
0.3, 0.9, 0.149
0.3, 1.0, 0.093
0.3, 1.1, 0.070
0.4, 0.1, 0.120
0.4, 0.2, 0.169
0.4, 0.3, 0.214
0.4, 0.4, 0.195
0.4, 0.5, 0.214
0.4, 0.6, 0.262
0.4, 0.7, 0.253
0.4, 0.8, 0.271
0.4, 0.9, 0.253
0.4, 1.0, 0.192
0.4, 1.1, 0.164
0.5, 0.1, 0.217
0.5, 0.2, 0.217
0.5, 0.3, 0.275
0.5, 0.4, 0.257
0.5, 0.5, 0.284
0.5, 0.6, 0.349
0.5, 0.7, 0.340
0.5, 0.8, 0.360
0.5, 0.9, 0.340
0.5, 1.0, 0.288
0.5, 1.1, 0.253
0.6, 0.1, 0.245
0.6, 0.2, 0.288
0.6, 0.3, 0.382
0.6, 0.4, 0.360
0.6, 0.5, 0.382
0.6, 0.6, 0.457
0.6, 0.7, 0.445
0.6, 0.8, 0.447
0.6, 0.9, 0.434
0.6, 1.0, 0.403
0.6, 1.1, 0.360
0.7, 0.1, 0.320
0.7, 0.2, 0.392
0.7, 0.3, 0.487
0.7, 0.4, 0.476
0.7, 0.5, 0.476
0.7, 0.6, 0.564
0.7, 0.7, 0.527
0.7, 0.8, 0.531
0.7, 0.9, 0.527
0.7, 1.0, 0.520
0.7, 1.1, 0.487
0.8, 0.1, 0.426
0.8, 0.2, 0.507
0.8, 0.3, 0.568
0.8, 0.4, 0.538
0.8, 0.5, 0.538
0.8, 0.6, 0.617
0.8, 0.7, 0.613
0.8, 0.8, 0.624
0.8, 0.9, 0.613
0.8, 1.0, 0.520
0.8, 1.1, 0.591
0.9, 0.1, 0.507
0.9, 0.2, 0.568
0.9, 0.3, 0.613
0.9, 0.4, 0.600
0.9, 0.5, 0.609
0.9, 0.6, 0.684
0.9, 0.7, 0.684
0.9, 0.8, 0.702
0.9, 0.9, 0.708
0.9, 1.0, 0.624
0.9, 1.1, 0.674
1.0, 0.1, 0.937
1.0, 0.2, 0.937
1.0, 0.3, 0.937
1.0, 0.4, 0.887
1.0, 0.5, 0.803
1.0, 0.6, 0.930
1.0, 0.7, 0.887
1.0, 0.8, 0.921
1.0, 0.9, 0.815
1.0, 1.0, 0.844
1.0, 1.1, 0.803
43 changes: 37 additions & 6 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,34 @@


# Test Function creation from .csv file
def test_function_from_csv(func_from_csv):
def test_function_from_csv(func_from_csv, func_2d_from_csv):
"""Test the Function class creation from a .csv file.
Parameters
----------
func_from_csv : rocketpy.Function
A Function object created from a .csv file.
func_2d_from_csv : rocketpy.Function
A Function object created from a .csv file with 2 inputs.
"""
# Assert the function is zero at 0 but with a certain tolerance
assert np.isclose(func_from_csv(0), 0.0, atol=1e-6)
assert np.isclose(func_2d_from_csv(0, 0), 0.0, atol=1e-6)
# Check the __str__ method
assert func_from_csv.__str__() == "Function from R1 to R1 : (Scalar) → (Scalar)"
assert (
func_2d_from_csv.__str__()
== "Function from R2 to R1 : (Input 1, Input 2) → (Scalar)"
)
# Check the __repr__ method
assert func_from_csv.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'"
assert (
func_2d_from_csv.__repr__()
== "'Function from R2 to R1 : (Input 1, Input 2) → (Scalar)'"
)


def test_getters(func_from_csv):
def test_getters(func_from_csv, func_2d_from_csv):
"""Test the different getters of the Function class.
Parameters
Expand All @@ -36,13 +47,20 @@ def test_getters(func_from_csv):
"""
assert func_from_csv.get_inputs() == ["Scalar"]
assert func_from_csv.get_outputs() == ["Scalar"]
assert func_from_csv.get_interpolation_method() == "linear"
assert func_from_csv.get_extrapolation_method() == "natural"
assert func_from_csv.get_interpolation_method() == "spline"
assert func_from_csv.get_extrapolation_method() == "constant"
assert np.isclose(func_from_csv.get_value(0), 0.0, atol=1e-6)
assert np.isclose(func_from_csv.get_value_opt(0), 0.0, atol=1e-6)

assert func_2d_from_csv.get_inputs() == ["Input 1", "Input 2"]
assert func_2d_from_csv.get_outputs() == ["Scalar"]
assert func_2d_from_csv.get_interpolation_method() == "shepard"
assert func_2d_from_csv.get_extrapolation_method() == "natural"
assert np.isclose(func_2d_from_csv.get_value(0, 0), 0.0, atol=1e-6)
assert np.isclose(func_2d_from_csv.get_value_opt(0, 0), 0.0, atol=1e-6)

def test_setters(func_from_csv):

def test_setters(func_from_csv, func_2d_from_csv):
"""Test the different setters of the Function class.
Parameters
Expand All @@ -60,9 +78,18 @@ def test_setters(func_from_csv):
func_from_csv.set_extrapolation("natural")
assert func_from_csv.get_extrapolation_method() == "natural"

func_2d_from_csv.set_inputs(["Scalar1", "Scalar2"])
assert func_2d_from_csv.get_inputs() == ["Scalar1", "Scalar2"]
func_2d_from_csv.set_outputs(["Scalar3"])
assert func_2d_from_csv.get_outputs() == ["Scalar3"]
func_2d_from_csv.set_interpolation("shepard")
assert func_2d_from_csv.get_interpolation_method() == "shepard"
func_2d_from_csv.set_extrapolation("zero")
assert func_2d_from_csv.get_extrapolation_method() == "zero"


@patch("matplotlib.pyplot.show")
def test_plots(mock_show, func_from_csv):
def test_plots(mock_show, func_from_csv, func_2d_from_csv):
"""Test different plot methods of the Function class.
Parameters
Expand All @@ -74,6 +101,10 @@ def test_plots(mock_show, func_from_csv):
"""
# Test plot methods
assert func_from_csv.plot() == None
assert func_2d_from_csv.plot() == None
# Test plot methods with limits
assert func_from_csv.plot(-1, 1) == None
assert func_2d_from_csv.plot(-1, 1) == None
# Test compare_plots
func2 = Function(
source="tests/fixtures/airfoils/e473-10e6-degrees.csv",
Expand Down

0 comments on commit e18ca22

Please sign in to comment.