Skip to content

Commit

Permalink
ENH: adds Function.remove_outliers method
Browse files Browse the repository at this point in the history
  • Loading branch information
Gui-FernandesBR committed Feb 12, 2024
1 parent e4e67f4 commit c9ad317
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added

- ENH: adds `Function.remove_outliers` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554)

### Changed

Expand Down
72 changes: 72 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,78 @@ def low_pass_filter(self, alpha, file_path=None):
title=self.title,
)

def remove_outliers(self, method="iqr", **kwargs):
"""Remove outliers from the Function source using the specified method.
Parameters
----------
method : string, optional
Method to be used to remove outliers. Options are 'iqr'.
Default is 'iqr'.
**kwargs : optional
Keyword arguments to be passed to specific methods according to the
selected method.
If the selected method is the 'iqr', then the following kwargs are
available:
- threshold : float
Threshold for the interquartile range method. Default is 1.5.
Returns
-------
Function
The new Function object without outliers.
"""
if callable(self.source):
print("Cannot remove outliers if the source is a callable object.")
return self

if method.lower() == "iqr":

Check warning on line 1170 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L1169-L1170

Added lines #L1169 - L1170 were not covered by tests
return self.__remove_outliers_iqr(**kwargs)
else:
print(
f"Method '{method}' not recognized. No outliers removed."
+ "Please use one of the following supported methods: 'iqr'."

Check warning on line 1175 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L1175

Added line #L1175 was not covered by tests
)
return self

def __remove_outliers_iqr(self, threshold=1.5):

Check warning on line 1179 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L1179

Added line #L1179 was not covered by tests
"""Remove outliers from the Function source using the interquartile
range method.
Parameters
----------
threshold : float, optional
Threshold for the interquartile range method. Default is 1.5.
Returns
-------
Function
The Function with the outliers removed.
References
----------
[1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences
"""
x = self.x_array
y = self.y_array
y_q1 = np.percentile(y, 25)
y_q3 = np.percentile(y, 75)
y_iqr = y_q3 - y_q1
y_lower = y_q1 - threshold * y_iqr
y_upper = y_q3 + threshold * y_iqr

y_filtered = y[(y >= y_lower) & (y <= y_upper)]
x_filtered = x[(y >= y_lower) & (y <= y_upper)]

return Function(
source=np.column_stack((x_filtered, y_filtered)),
inputs=self.__inputs__,
outputs=self.__outputs__,
interpolation=self.__interpolation__,
extrapolation=self.__extrapolation__,
title=self.title,
)

# Define all presentation methods
def __call__(self, *args):
"""Plot the Function if no argument is given. If an
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func):
assert isinstance(func, Function)
assert discretized_func.source.shape == (4, 2)
assert callable(func.source)


@pytest.mark.parametrize(
"x, y, expected_x, expected_y",
[
(
np.array([1, 2, 3, 4, 5, 6]),
np.array([10, 20, 30, 40, 50000, 60]),
np.array([1, 2, 3, 4, 6]),
np.array([10, 20, 30, 40, 60]),
),
],
)
def test_remove_outliers_iqr(x, y, expected_x, expected_y):
"""Test the function __remove_outliers_iqr which is expected to remove
outliers from the data based on the Interquartile Range (IQR) method.
"""
func = Function(source=np.column_stack((x, y)))
filtered_func = func.remove_outliers(method="iqr", threshold=1.5)

# Check if the outliers are removed
assert np.array_equal(filtered_func.x_array, expected_x)
assert np.array_equal(filtered_func.y_array, expected_y)

# Check if the other attributes are preserved
assert filtered_func.__inputs__ == func.__inputs__
assert filtered_func.__outputs__ == func.__outputs__
assert filtered_func.__interpolation__ == func.__interpolation__
assert filtered_func.__extrapolation__ == func.__extrapolation__
assert filtered_func.title == func.title

0 comments on commit c9ad317

Please sign in to comment.