Skip to content

Commit

Permalink
TST: complementing tests for sensitivity analysis and removing duplicate piece of code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-Prates authored and Gui-FernandesBR committed Dec 16, 2024
1 parent c07b50b commit 6863d6e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
8 changes: 2 additions & 6 deletions rocketpy/sensitivity/sensitivity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ def set_target_variables_nominal(self, target_variables_nominal_value):
self.target_variables_info[target_variable]["nominal_value"] = (
target_variables_nominal_value[i]
)
for i, target_variable in enumerate(self.target_variables_names):
self.target_variables_info[target_variable]["nominal_value"] = (
target_variables_nominal_value[i]
)

self._nominal_target_passed = True

Expand Down Expand Up @@ -356,12 +352,12 @@ def __check_requirements(self):
version = ">=0" if not version else version
try:
check_requirement_version(module_name, version)
except (ValueError, ImportError) as e:
except (ValueError, ImportError) as e: # pragma: no cover
has_error = True
print(
f"The following error occurred while importing {module_name}: {e}"
)
if has_error:
if has_error: # pragma: no cover
print(
"Given the above errors, some methods may not work. Please run "
+ "'pip install rocketpy[sensitivity]' to install extra requirements."
Expand Down
81 changes: 70 additions & 11 deletions tests/unit/test_sensitivity.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from unittest.mock import patch

import numpy as np
import pytest

from rocketpy.sensitivity import SensitivityModel

# TODO: for some weird reason, these tests are not passing in the CI, but
# passing locally. Need to investigate why.


@pytest.mark.skip(reason="legacy test")
def test_initialization():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1", "target2"]
Expand All @@ -21,7 +19,6 @@ def test_initialization():
assert not model._fitted


@pytest.mark.skip(reason="legacy test")
def test_set_parameters_nominal():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1", "target2"]
Expand All @@ -35,8 +32,16 @@ def test_set_parameters_nominal():
assert model.parameters_info["param1"]["nominal_mean"] == 1.0
assert model.parameters_info["param2"]["nominal_sd"] == 0.2

# check dimensions mismatch error raise
incorrect_nominal_mean = np.array([1.0])
with pytest.raises(ValueError):
model.set_parameters_nominal(incorrect_nominal_mean, parameters_nominal_sd)

incorrect_nominal_sd = np.array([0.1])
with pytest.raises(ValueError):
model.set_parameters_nominal(parameters_nominal_mean, incorrect_nominal_sd)


@pytest.mark.skip(reason="legacy test")
def test_set_target_variables_nominal():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1", "target2"]
Expand All @@ -49,9 +54,13 @@ def test_set_target_variables_nominal():
assert model.target_variables_info["target1"]["nominal_value"] == 10.0
assert model.target_variables_info["target2"]["nominal_value"] == 20.0

# check dimensions mismatch error raise
incorrect_nominal_value = np.array([10.0])
with pytest.raises(ValueError):
model.set_target_variables_nominal(incorrect_nominal_value)


@pytest.mark.skip(reason="legacy test")
def test_fit_method():
def test_fit_method_one_target():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1"]
model = SensitivityModel(parameters_names, target_variables_names)
Expand All @@ -65,7 +74,20 @@ def test_fit_method():
assert model.number_of_samples == 3


@pytest.mark.skip(reason="legacy test")
def test_fit_method_multiple_target():
    """Check that fitting succeeds when the model has several target variables."""
    model = SensitivityModel(["param1", "param2"], ["target1", "target2"])

    # three samples of two parameters, and the matching pair of target columns
    sample_inputs = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
    sample_outputs = np.column_stack(
        ([10.0, 12.0, 14.0], [11.0, 13.0, 17.0])
    )

    model.fit(sample_inputs, sample_outputs)

    assert model._fitted
    assert model.number_of_samples == 3


def test_fit_raises_error_on_mismatched_dimensions():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1"]
Expand All @@ -78,7 +100,6 @@ def test_fit_raises_error_on_mismatched_dimensions():
model.fit(parameters_matrix, target_data)


@pytest.mark.skip(reason="legacy test")
def test_check_conformity():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1", "target2"]
Expand All @@ -90,7 +111,6 @@ def test_check_conformity():
model._SensitivityModel__check_conformity(parameters_matrix, target_data)


@pytest.mark.skip(reason="legacy test")
def test_check_conformity_raises_error():
parameters_names = ["param1", "param2"]
target_variables_names = ["target1", "target2"]
Expand All @@ -101,3 +121,42 @@ def test_check_conformity_raises_error():

with pytest.raises(ValueError):
model._SensitivityModel__check_conformity(parameters_matrix, target_data)

parameters_matrix2 = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])

with pytest.raises(ValueError):
model._SensitivityModel__check_conformity(parameters_matrix2, target_data)

target_data2 = np.array([10.0, 12.0])

with pytest.raises(ValueError):
model._SensitivityModel__check_conformity(parameters_matrix, target_data2)

target_variables_names = ["target1"]
model = SensitivityModel(parameters_names, target_variables_names)

target_data = np.array([[10.0, 20.0], [12.0, 22.0], [14.0, 24.0]])

with pytest.raises(ValueError):
model._SensitivityModel__check_conformity(parameters_matrix, target_data)


@patch("matplotlib.pyplot.show")
def test_prints_and_plots(mock_show):  # pylint: disable=unused-argument
    """Smoke-test the info/all_info reporting before and after fitting.

    matplotlib's ``show`` is patched out so the plot calls do not block.
    """
    model = SensitivityModel(["param1", "param2"], ["target1"])

    samples = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
    observations = np.array([10.0, 12.0, 14.0])

    # requesting a summary before the model is fitted must raise
    with pytest.raises(ValueError):
        model.info()

    model.fit(samples, observations)
    assert model.all_info() is None

    # setting nominal targets should not break the reporting path either
    model.set_target_variables_nominal(np.array([12.0]))
    assert model.all_info() is None

0 comments on commit 6863d6e

Please sign in to comment.