Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HOTFIX: 2D .CSV Function and missing set_get_value_opt call #478

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,6 @@ def source_function(_):

# Finally set data source as source
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
# Do things if function is multivariate
else:
self.x_array = source[:, 0]
Expand All @@ -251,6 +242,15 @@ def source_function(_):

# Finally set data source as source
self.source = source
# Update extrapolation method
if self.__extrapolation__ is None:
self.set_extrapolation()
# Set default interpolation for point source if it hasn't
if self.__interpolation__ is None:
self.set_interpolation()
else:
# Updates interpolation coefficients
self.set_interpolation(self.__interpolation__)
return self

@cached_property
Expand Down Expand Up @@ -329,7 +329,7 @@ def set_extrapolation(self, method="constant"):
def set_get_value_opt(self):
"""Crates a method that evaluates interpolations rather quickly
when compared to other options available, such as just calling
the object instance or calling ``Function.get_value directly``. See
the object instance or calling ``Function.get_value`` directly. See
``Function.get_value_opt`` for documentation.

Returns
Expand Down Expand Up @@ -2825,7 +2825,12 @@ def _check_user_input(

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, (list, np.ndarray)):
if isinstance(source, (list, np.ndarray, str, Path)):
# Deal with csv or txt
if isinstance(source, (str, Path)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe a try-except with FileNotFound would be beneficial here, just because it is a good practice. But again, this is also done in set_source.

# Convert to numpy array
source = np.loadtxt(source, delimiter=",", dtype=float)

Comment on lines +2829 to +2834
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that this is necessary in order to check the validity of the inputs/outputs, but it is quite strange to read the csv here, throw away the result, and read it again in set_source.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that loading it twice is really not a good idea. It can be a huge problem if the CSV file is large and takes a while to read.

Furthermore, I would not recommend changing this in this hotfix PR. It doesn't exactly fix a bug.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that this is necessary in order to check the validity of the inputs/outputs, but it is quite strange to read the csv here, throw away the result, and read it again in set_source.

Yep, it is pretty strange, but the whole _check_user_input is already strange. It is called twice (and I am not sure but depending on the source type it might be called 3 times) during the __init__. Also it has to process the source to be able to check if it is a valid input, but it does not save that processing, the saving is left to be dealt in set_source where it has to do all the calculations again. The source is processed twice for sources of type list and np.array as well

And if the source is anything besides (list, np.ndarray, str, Path) this _check_user_input method does essentially nothing

Also _check_user_input is a static_method for some reason?

What I am getting at is that all the user's inputs arguments should not be checked inside the same function, but rather in the calls of the set methods for those attributes. Maybe even inside .setter decorators. Structure wise, this method does not really work for the Function class, but this is a only a hotfix pr, we can focus on improving the method later

Furthermore, I would not recommend changing this in this hotfix PR. It doesn't exactly fix a bug.

This str and Path check is half of the bug corrections in this PR. Without these lines 2D .csv Functions will not have interpolation

Copy link
Collaborator

@phmbressan phmbressan Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, in retrospect, it is a bit odd to perform the same operations multiple times in check inputs and while setting the source.

I approved the PR, since I don't think changes to the check method is in the scope of a hotfix. The PR fixes the issue with csv_2d. In a future PR the check inputs could be made as a way of pre-processing the source so as to only return a np.array or a callable to the set_source.

This would decouple responsabilities a fair bit and could be implemented in various ways, maybe with functools.singledispatch.

# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)
Expand Down
20 changes: 16 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,9 +1112,21 @@ def func_from_csv():
"""
func = Function(
source="tests/fixtures/airfoils/e473-10e6-degrees.csv",
inputs=["Scalar"],
outputs=["Scalar"],
interpolation="linear",
extrapolation="natural",
)
return func


@pytest.fixture
def func_2d_from_csv():
"""Create a 2d function based on a csv file. The csv file contains the
MateusStano marked this conversation as resolved.
Show resolved Hide resolved
Returns
-------
rocketpy.Function
A function based on a csv file.
"""
# Do not define any of the optional parameters so that the tests can check
# if the defaults are being used correctly.
func = Function(
source="tests/fixtures/function/2d.csv",
)
return func
133 changes: 133 additions & 0 deletions tests/fixtures/function/2d.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
0.0, 0.0, 0.000
0.0, 0.1, 0.000
0.0, 0.2, 0.000
0.0, 0.3, 0.000
0.0, 0.4, 0.000
0.0, 0.5, 0.000
0.0, 0.6, 0.000
0.0, 0.7, 0.000
0.0, 0.8, 0.000
0.0, 0.9, 0.000
0.0, 1.0, 0.000
0.0, 1.1, 0.000
0.1, 0.0, 0.000
0.2, 0.0, 0.000
0.3, 0.0, 0.000
0.4, 0.0, 0.000
0.5, 0.0, 0.000
0.6, 0.0, 0.000
0.7, 0.0, 0.000
0.8, 0.0, 0.000
0.9, 0.0, 0.000
1.0, 0.0, 0.000
1.1, 0.0, 0.000
0.1, 0.1, 0.000
0.1, 0.2, 0.000
0.1, 0.3, 0.045
0.1, 0.4, 0.022
0.1, 0.5, 0.026
0.1, 0.6, 0.076
0.1, 0.7, 0.051
0.1, 0.8, 0.058
0.1, 0.9, 0.022
0.1, 1.0, 0.000
0.1, 1.1, 0.000
0.2, 0.1, 0.022
0.2, 0.2, 0.062
0.2, 0.3, 0.093
0.2, 0.4, 0.076
0.2, 0.5, 0.070
0.2, 0.6, 0.129
0.2, 0.7, 0.102
0.2, 0.8, 0.115
0.2, 0.9, 0.084
0.2, 1.0, 0.018
0.2, 1.1, 0.000
0.3, 0.1, 0.056
0.3, 0.2, 0.106
0.3, 0.3, 0.147
0.3, 0.4, 0.139
0.3, 0.5, 0.139
0.3, 0.6, 0.183
0.3, 0.7, 0.169
0.3, 0.8, 0.183
0.3, 0.9, 0.149
0.3, 1.0, 0.093
0.3, 1.1, 0.070
0.4, 0.1, 0.120
0.4, 0.2, 0.169
0.4, 0.3, 0.214
0.4, 0.4, 0.195
0.4, 0.5, 0.214
0.4, 0.6, 0.262
0.4, 0.7, 0.253
0.4, 0.8, 0.271
0.4, 0.9, 0.253
0.4, 1.0, 0.192
0.4, 1.1, 0.164
0.5, 0.1, 0.217
0.5, 0.2, 0.217
0.5, 0.3, 0.275
0.5, 0.4, 0.257
0.5, 0.5, 0.284
0.5, 0.6, 0.349
0.5, 0.7, 0.340
0.5, 0.8, 0.360
0.5, 0.9, 0.340
0.5, 1.0, 0.288
0.5, 1.1, 0.253
0.6, 0.1, 0.245
0.6, 0.2, 0.288
0.6, 0.3, 0.382
0.6, 0.4, 0.360
0.6, 0.5, 0.382
0.6, 0.6, 0.457
0.6, 0.7, 0.445
0.6, 0.8, 0.447
0.6, 0.9, 0.434
0.6, 1.0, 0.403
0.6, 1.1, 0.360
0.7, 0.1, 0.320
0.7, 0.2, 0.392
0.7, 0.3, 0.487
0.7, 0.4, 0.476
0.7, 0.5, 0.476
0.7, 0.6, 0.564
0.7, 0.7, 0.527
0.7, 0.8, 0.531
0.7, 0.9, 0.527
0.7, 1.0, 0.520
0.7, 1.1, 0.487
0.8, 0.1, 0.426
0.8, 0.2, 0.507
0.8, 0.3, 0.568
0.8, 0.4, 0.538
0.8, 0.5, 0.538
0.8, 0.6, 0.617
0.8, 0.7, 0.613
0.8, 0.8, 0.624
0.8, 0.9, 0.613
0.8, 1.0, 0.520
0.8, 1.1, 0.591
0.9, 0.1, 0.507
0.9, 0.2, 0.568
0.9, 0.3, 0.613
0.9, 0.4, 0.600
0.9, 0.5, 0.609
0.9, 0.6, 0.684
0.9, 0.7, 0.684
0.9, 0.8, 0.702
0.9, 0.9, 0.708
0.9, 1.0, 0.624
0.9, 1.1, 0.674
1.0, 0.1, 0.937
1.0, 0.2, 0.937
1.0, 0.3, 0.937
1.0, 0.4, 0.887
1.0, 0.5, 0.803
1.0, 0.6, 0.930
1.0, 0.7, 0.887
1.0, 0.8, 0.921
1.0, 0.9, 0.815
1.0, 1.0, 0.844
1.0, 1.1, 0.803
43 changes: 37 additions & 6 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,34 @@


# Test Function creation from .csv file
def test_function_from_csv(func_from_csv):
def test_function_from_csv(func_from_csv, func_2d_from_csv):
"""Test the Function class creation from a .csv file.

Parameters
----------
func_from_csv : rocketpy.Function
A Function object created from a .csv file.
func_2d_from_csv : rocketpy.Function
A Function object created from a .csv file with 2 inputs.
"""
# Assert the function is zero at 0 but with a certain tolerance
assert np.isclose(func_from_csv(0), 0.0, atol=1e-6)
assert np.isclose(func_2d_from_csv(0, 0), 0.0, atol=1e-6)
# Check the __str__ method
assert func_from_csv.__str__() == "Function from R1 to R1 : (Scalar) → (Scalar)"
assert (
func_2d_from_csv.__str__()
== "Function from R2 to R1 : (Input 1, Input 2) → (Scalar)"
)
# Check the __repr__ method
assert func_from_csv.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'"
assert (
func_2d_from_csv.__repr__()
== "'Function from R2 to R1 : (Input 1, Input 2) → (Scalar)'"
)


def test_getters(func_from_csv):
def test_getters(func_from_csv, func_2d_from_csv):
"""Test the different getters of the Function class.

Parameters
Expand All @@ -36,13 +47,20 @@ def test_getters(func_from_csv):
"""
assert func_from_csv.get_inputs() == ["Scalar"]
assert func_from_csv.get_outputs() == ["Scalar"]
assert func_from_csv.get_interpolation_method() == "linear"
assert func_from_csv.get_extrapolation_method() == "natural"
assert func_from_csv.get_interpolation_method() == "spline"
assert func_from_csv.get_extrapolation_method() == "constant"
assert np.isclose(func_from_csv.get_value(0), 0.0, atol=1e-6)
assert np.isclose(func_from_csv.get_value_opt(0), 0.0, atol=1e-6)

assert func_2d_from_csv.get_inputs() == ["Input 1", "Input 2"]
assert func_2d_from_csv.get_outputs() == ["Scalar"]
assert func_2d_from_csv.get_interpolation_method() == "shepard"
assert func_2d_from_csv.get_extrapolation_method() == "natural"
assert np.isclose(func_2d_from_csv.get_value(0, 0), 0.0, atol=1e-6)
assert np.isclose(func_2d_from_csv.get_value_opt(0, 0), 0.0, atol=1e-6)

def test_setters(func_from_csv):

def test_setters(func_from_csv, func_2d_from_csv):
"""Test the different setters of the Function class.

Parameters
Expand All @@ -60,9 +78,18 @@ def test_setters(func_from_csv):
func_from_csv.set_extrapolation("natural")
assert func_from_csv.get_extrapolation_method() == "natural"

func_2d_from_csv.set_inputs(["Scalar1", "Scalar2"])
assert func_2d_from_csv.get_inputs() == ["Scalar1", "Scalar2"]
func_2d_from_csv.set_outputs(["Scalar3"])
assert func_2d_from_csv.get_outputs() == ["Scalar3"]
func_2d_from_csv.set_interpolation("shepard")
assert func_2d_from_csv.get_interpolation_method() == "shepard"
func_2d_from_csv.set_extrapolation("zero")
assert func_2d_from_csv.get_extrapolation_method() == "zero"


@patch("matplotlib.pyplot.show")
def test_plots(mock_show, func_from_csv):
def test_plots(mock_show, func_from_csv, func_2d_from_csv):
"""Test different plot methods of the Function class.

Parameters
Expand All @@ -74,6 +101,10 @@ def test_plots(mock_show, func_from_csv):
"""
# Test plot methods
assert func_from_csv.plot() == None
assert func_2d_from_csv.plot() == None
# Test plot methods with limits
assert func_from_csv.plot(-1, 1) == None
assert func_2d_from_csv.plot(-1, 1) == None
# Test compare_plots
func2 = Function(
source="tests/fixtures/airfoils/e473-10e6-degrees.csv",
Expand Down
Loading