From 54c30f88d146242d307387fb423913d1bc618ff4 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Tue, 21 Nov 2023 23:28:00 +0100 Subject: [PATCH 1/4] BUG: fix 2d csv Function definition and missing set_get_value_opt --- rocketpy/mathutils/function.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index ed379e606..344581ce4 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -229,15 +229,6 @@ def source_function(_): # Finally set data source as source self.source = source - # Update extrapolation method - if self.__extrapolation__ is None: - self.set_extrapolation() - # Set default interpolation for point source if it hasn't - if self.__interpolation__ is None: - self.set_interpolation() - else: - # Updates interpolation coefficients - self.set_interpolation(self.__interpolation__) # Do things if function is multivariate else: self.x_array = source[:, 0] @@ -251,6 +242,15 @@ def source_function(_): # Finally set data source as source self.source = source + # Update extrapolation method + if self.__extrapolation__ is None: + self.set_extrapolation() + # Set default interpolation for point source if it hasn't + if self.__interpolation__ is None: + self.set_interpolation() + else: + # Updates interpolation coefficients + self.set_interpolation(self.__interpolation__) return self @cached_property @@ -329,7 +329,7 @@ def set_extrapolation(self, method="constant"): def set_get_value_opt(self): """Crates a method that evaluates interpolations rather quickly when compared to other options available, such as just calling - the object instance or calling ``Function.get_value directly``. See + the object instance or calling ``Function.get_value`` directly. See ``Function.get_value_opt`` for documentation. Returns @@ -2825,7 +2825,12 @@ def _check_user_input( # check source for data type # if list or ndarray, check for dimensions, interpolation and extrapolation - if isinstance(source, (list, np.ndarray)): + if isinstance(source, (list, np.ndarray, str, Path)): + # Deal with csv or txt + if isinstance(source, (str, Path)): + # Convert to numpy array + source = np.loadtxt(source, delimiter=",", dtype=float) + # this will also trigger an error if the source is not a list of # numbers or if the array is not homogeneous source = np.array(source, dtype=np.float64) From 0b2eb80282c793aa844135bbd298eafdcd1ff860 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Tue, 21 Nov 2023 23:30:05 +0100 Subject: [PATCH 2/4] TST: add tests with 2d csv Funciton --- tests/conftest.py | 18 ++++- tests/fixtures/function/2d.csv | 133 +++++++++++++++++++++++++++++++++ tests/test_function.py | 43 +++++++++-- 3 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 tests/fixtures/function/2d.csv diff --git a/tests/conftest.py b/tests/conftest.py index 122b3f49a..05f0992a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1112,9 +1112,23 @@ def func_from_csv(): """ func = Function( source="tests/fixtures/airfoils/e473-10e6-degrees.csv", - inputs=["Scalar"], + ) outputs=["Scalar"], + return func interpolation="linear", + extrapolation="natural", + +@pytest.fixture +def func_2d_from_csv(): + """Create a 2d function based on a csv file. The csv file contains the + Returns + ------- + rocketpy.Function + A function based on a csv file. + """ + # Do not define any of the optional parameters so that the tests can check + # if the defaults are being used correctly. + func = Function( + source="tests/fixtures/function/2d.csv", ) - return func diff --git a/tests/fixtures/function/2d.csv b/tests/fixtures/function/2d.csv new file mode 100644 index 000000000..b212ed041 --- /dev/null +++ b/tests/fixtures/function/2d.csv @@ -0,0 +1,133 @@ +0.0, 0.0, 0.000 +0.0, 0.1, 0.000 +0.0, 0.2, 0.000 +0.0, 0.3, 0.000 +0.0, 0.4, 0.000 +0.0, 0.5, 0.000 +0.0, 0.6, 0.000 +0.0, 0.7, 0.000 +0.0, 0.8, 0.000 +0.0, 0.9, 0.000 +0.0, 1.0, 0.000 +0.0, 1.1, 0.000 +0.1, 0.0, 0.000 +0.2, 0.0, 0.000 +0.3, 0.0, 0.000 +0.4, 0.0, 0.000 +0.5, 0.0, 0.000 +0.6, 0.0, 0.000 +0.7, 0.0, 0.000 +0.8, 0.0, 0.000 +0.9, 0.0, 0.000 +1.0, 0.0, 0.000 +1.1, 0.0, 0.000 +0.1, 0.1, 0.000 +0.1, 0.2, 0.000 +0.1, 0.3, 0.045 +0.1, 0.4, 0.022 +0.1, 0.5, 0.026 +0.1, 0.6, 0.076 +0.1, 0.7, 0.051 +0.1, 0.8, 0.058 +0.1, 0.9, 0.022 +0.1, 1.0, 0.000 +0.1, 1.1, 0.000 +0.2, 0.1, 0.022 +0.2, 0.2, 0.062 +0.2, 0.3, 0.093 +0.2, 0.4, 0.076 +0.2, 0.5, 0.070 +0.2, 0.6, 0.129 +0.2, 0.7, 0.102 +0.2, 0.8, 0.115 +0.2, 0.9, 0.084 +0.2, 1.0, 0.018 +0.2, 1.1, 0.000 +0.3, 0.1, 0.056 +0.3, 0.2, 0.106 +0.3, 0.3, 0.147 +0.3, 0.4, 0.139 +0.3, 0.5, 0.139 +0.3, 0.6, 0.183 +0.3, 0.7, 0.169 +0.3, 0.8, 0.183 +0.3, 0.9, 0.149 +0.3, 1.0, 0.093 +0.3, 1.1, 0.070 +0.4, 0.1, 0.120 +0.4, 0.2, 0.169 +0.4, 0.3, 0.214 +0.4, 0.4, 0.195 +0.4, 0.5, 0.214 +0.4, 0.6, 0.262 +0.4, 0.7, 0.253 +0.4, 0.8, 0.271 +0.4, 0.9, 0.253 +0.4, 1.0, 0.192 +0.4, 1.1, 0.164 +0.5, 0.1, 0.217 +0.5, 0.2, 0.217 +0.5, 0.3, 0.275 +0.5, 0.4, 0.257 +0.5, 0.5, 0.284 +0.5, 0.6, 0.349 +0.5, 0.7, 0.340 +0.5, 0.8, 0.360 +0.5, 0.9, 0.340 +0.5, 1.0, 0.288 +0.5, 1.1, 0.253 +0.6, 0.1, 0.245 +0.6, 0.2, 0.288 +0.6, 0.3, 0.382 +0.6, 0.4, 0.360 +0.6, 0.5, 0.382 +0.6, 0.6, 0.457 +0.6, 0.7, 0.445 +0.6, 0.8, 0.447 +0.6, 0.9, 0.434 +0.6, 1.0, 0.403 +0.6, 1.1, 0.360 +0.7, 0.1, 0.320 +0.7, 0.2, 0.392 +0.7, 0.3, 0.487 +0.7, 0.4, 0.476 +0.7, 0.5, 0.476 +0.7, 0.6, 0.564 +0.7, 0.7, 0.527 +0.7, 0.8, 0.531 +0.7, 0.9, 0.527 +0.7, 1.0, 0.520 +0.7, 1.1, 0.487 +0.8, 0.1, 0.426 +0.8, 0.2, 0.507 +0.8, 0.3, 0.568 +0.8, 0.4, 0.538 +0.8, 0.5, 0.538 +0.8, 0.6, 0.617 +0.8, 0.7, 0.613 +0.8, 0.8, 0.624 +0.8, 0.9, 0.613 +0.8, 1.0, 0.520 +0.8, 1.1, 0.591 +0.9, 0.1, 0.507 +0.9, 0.2, 0.568 +0.9, 0.3, 0.613 +0.9, 0.4, 0.600 +0.9, 0.5, 0.609 +0.9, 0.6, 0.684 +0.9, 0.7, 0.684 +0.9, 0.8, 0.702 +0.9, 0.9, 0.708 +0.9, 1.0, 0.624 +0.9, 1.1, 0.674 +1.0, 0.1, 0.937 +1.0, 0.2, 0.937 +1.0, 0.3, 0.937 +1.0, 0.4, 0.887 +1.0, 0.5, 0.803 +1.0, 0.6, 0.930 +1.0, 0.7, 0.887 +1.0, 0.8, 0.921 +1.0, 0.9, 0.815 +1.0, 1.0, 0.844 +1.0, 1.1, 0.803 \ No newline at end of file diff --git a/tests/test_function.py b/tests/test_function.py index 545790b6d..1a5b4f189 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -10,23 +10,34 @@ # Test Function creation from .csv file -def test_function_from_csv(func_from_csv): +def test_function_from_csv(func_from_csv, func_2d_from_csv): """Test the Function class creation from a .csv file. Parameters ---------- func_from_csv : rocketpy.Function A Function object created from a .csv file. + func_2d_from_csv : rocketpy.Function + A Function object created from a .csv file with 2 inputs. """ # Assert the function is zero at 0 but with a certain tolerance assert np.isclose(func_from_csv(0), 0.0, atol=1e-6) + assert np.isclose(func_2d_from_csv(0, 0), 0.0, atol=1e-6) # Check the __str__ method assert func_from_csv.__str__() == "Function from R1 to R1 : (Scalar) → (Scalar)" + assert ( + func_2d_from_csv.__str__() + == "Function from R2 to R1 : (Input 1, Input 2) → (Scalar)" + ) # Check the __repr__ method assert func_from_csv.__repr__() == "'Function from R1 to R1 : (Scalar) → (Scalar)'" + assert ( + func_2d_from_csv.__repr__() + == "'Function from R2 to R1 : (Input 1, Input 2) → (Scalar)'" + ) -def test_getters(func_from_csv): +def test_getters(func_from_csv, func_2d_from_csv): """Test the different getters of the Function class. Parameters @@ -36,13 +47,20 @@ def test_getters(func_from_csv): """ assert func_from_csv.get_inputs() == ["Scalar"] assert func_from_csv.get_outputs() == ["Scalar"] - assert func_from_csv.get_interpolation_method() == "linear" - assert func_from_csv.get_extrapolation_method() == "natural" + assert func_from_csv.get_interpolation_method() == "spline" + assert func_from_csv.get_extrapolation_method() == "constant" assert np.isclose(func_from_csv.get_value(0), 0.0, atol=1e-6) assert np.isclose(func_from_csv.get_value_opt(0), 0.0, atol=1e-6) + assert func_2d_from_csv.get_inputs() == ["Input 1", "Input 2"] + assert func_2d_from_csv.get_outputs() == ["Scalar"] + assert func_2d_from_csv.get_interpolation_method() == "shepard" + assert func_2d_from_csv.get_extrapolation_method() == "natural" + assert np.isclose(func_2d_from_csv.get_value(0, 0), 0.0, atol=1e-6) + assert np.isclose(func_2d_from_csv.get_value_opt(0, 0), 0.0, atol=1e-6) -def test_setters(func_from_csv): + +def test_setters(func_from_csv, func_2d_from_csv): """Test the different setters of the Function class. Parameters @@ -60,9 +78,18 @@ def test_setters(func_from_csv): func_from_csv.set_extrapolation("natural") assert func_from_csv.get_extrapolation_method() == "natural" + func_2d_from_csv.set_inputs(["Scalar1", "Scalar2"]) + assert func_2d_from_csv.get_inputs() == ["Scalar1", "Scalar2"] + func_2d_from_csv.set_outputs(["Scalar3"]) + assert func_2d_from_csv.get_outputs() == ["Scalar3"] + func_2d_from_csv.set_interpolation("shepard") + assert func_2d_from_csv.get_interpolation_method() == "shepard" + func_2d_from_csv.set_extrapolation("zero") + assert func_2d_from_csv.get_extrapolation_method() == "zero" + @patch("matplotlib.pyplot.show") -def test_plots(mock_show, func_from_csv): +def test_plots(mock_show, func_from_csv, func_2d_from_csv): """Test different plot methods of the Function class. Parameters @@ -74,6 +101,10 @@ def test_plots(mock_show, func_from_csv): """ # Test plot methods assert func_from_csv.plot() == None + assert func_2d_from_csv.plot() == None + # Test plot methods with limits + assert func_from_csv.plot(-1, 1) == None + assert func_2d_from_csv.plot(-1, 1) == None # Test compare_plots func2 = Function( source="tests/fixtures/airfoils/e473-10e6-degrees.csv", From 8abe6af79803fff9d7f577b9196e8ba556037fd7 Mon Sep 17 00:00:00 2001 From: MateusStano Date: Tue, 21 Nov 2023 23:43:32 +0100 Subject: [PATCH 3/4] BUG: typo in conftest --- tests/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 05f0992a1..939256be7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1113,11 +1113,8 @@ def func_from_csv(): func = Function( source="tests/fixtures/airfoils/e473-10e6-degrees.csv", ) - outputs=["Scalar"], return func - interpolation="linear", - extrapolation="natural", @pytest.fixture def func_2d_from_csv(): @@ -1132,3 +1129,4 @@ def func_2d_from_csv(): func = Function( source="tests/fixtures/function/2d.csv", ) + return func From 2ee958ae5ded6b24455a855e49058207133af46f Mon Sep 17 00:00:00 2001 From: MateusStano Date: Wed, 22 Nov 2023 14:57:26 +0100 Subject: [PATCH 4/4] MNT: few docs fix and remove extra np.array conversion for csvs --- rocketpy/mathutils/function.py | 10 ++++++---- tests/conftest.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 344581ce4..e18495a1d 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -2785,7 +2785,8 @@ def _check_user_input( dimensions of inputs and outputs. If the outputs list has more than one element. TypeError - If the source is not a list, np.ndarray, or Function object. + If the source is not a list, np.ndarray, Function object, str or + Path. Warning If inputs or outputs do not match for a Function source, or if defaults are used for inputs, interpolation,and extrapolation for a @@ -2831,9 +2832,10 @@ def _check_user_input( # Convert to numpy array source = np.loadtxt(source, delimiter=",", dtype=float) - # this will also trigger an error if the source is not a list of - # numbers or if the array is not homogeneous - source = np.array(source, dtype=np.float64) + else: + # this will also trigger an error if the source is not a list of + # numbers or if the array is not homogeneous + source = np.array(source, dtype=np.float64) # check dimensions source_dim = source.shape[1] diff --git a/tests/conftest.py b/tests/conftest.py index 939256be7..2e4b040f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1118,7 +1118,8 @@ def func_from_csv(): @pytest.fixture def func_2d_from_csv(): - """Create a 2d function based on a csv file. The csv file contains the + """Create a 2d function based on a csv file. + Returns ------- rocketpy.Function