Skip to content

Commit

Permalink
Merge pull request #554 from RocketPy-Team/enh/function-remove-outliers
Browse files Browse the repository at this point in the history
ENH: adds `Function.remove_outliers` method
  • Loading branch information
Gui-FernandesBR authored Feb 27, 2024
2 parents 35a9439 + fda6b33 commit 25374fa
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added

- ENH: adds `Function.remove_outliers_iqr` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554)

### Changed
- ENH: Optional argument to show the plot in Function.compare_plots [#563](https://github.com/RocketPy-Team/RocketPy/pull/563)
Expand Down
45 changes: 45 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,51 @@ def low_pass_filter(self, alpha, file_path=None):
title=self.title,
)

def remove_outliers_iqr(self, threshold=1.5):
    """Remove outliers from the Function source using the interquartile
    range method. The Function should have an array-like source.

    A point is kept when its y value lies inside the closed interval
    ``[Q1 - threshold * IQR, Q3 + threshold * IQR]``, where ``Q1`` and
    ``Q3`` are the 25th and 75th percentiles of ``self.y_array`` and
    ``IQR = Q3 - Q1``. Points outside that interval are dropped together
    with their matching x values.

    Parameters
    ----------
    threshold : float, optional
        Threshold for the interquartile range method. Default is 1.5.

    Returns
    -------
    Function
        A new Function built from the retained points, created with the
        same inputs, outputs, interpolation, extrapolation and title as
        this one.

    Raises
    ------
    TypeError
        If the Function source is a callable object.

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences
    """
    if callable(self.source):
        raise TypeError(
            "Cannot remove outliers if the source is a callable object."
            + " The Function.source should be array-like."
        )

    x = self.x_array
    y = self.y_array

    # Tukey's fences: quartiles widened by `threshold` times the IQR.
    y_q1, y_q3 = np.percentile(y, [25, 75])
    y_iqr = y_q3 - y_q1
    y_lower = y_q1 - threshold * y_iqr
    y_upper = y_q3 + threshold * y_iqr

    # Build the inlier mask once and apply it to both arrays so that the
    # x and y samples stay paired.
    inliers = (y >= y_lower) & (y <= y_upper)

    return Function(
        source=np.column_stack((x[inliers], y[inliers])),
        inputs=self.__inputs__,
        outputs=self.__outputs__,
        interpolation=self.__interpolation__,
        extrapolation=self.__extrapolation__,
        title=self.title,
    )

# Define all presentation methods
def __call__(self, *args):
"""Plot the Function if no argument is given. If an
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func):
assert isinstance(func, Function)
assert discretized_func.source.shape == (4, 2)
assert callable(func.source)


@pytest.mark.parametrize(
    "x, y, expected_x, expected_y",
    [
        (
            np.array([1, 2, 3, 4, 5, 6]),
            np.array([10, 20, 30, 40, 50000, 60]),
            np.array([1, 2, 3, 4, 6]),
            np.array([10, 20, 30, 40, 60]),
        ),
    ],
)
def test_remove_outliers_iqr(x, y, expected_x, expected_y):
    """Check that remove_outliers_iqr drops the points flagged by the
    Interquartile Range (IQR) method and keeps the remaining Function
    attributes untouched.
    """
    original = Function(source=np.column_stack((x, y)))
    cleaned = original.remove_outliers_iqr(threshold=1.5)

    # Only the inlier samples should survive, still paired x-to-y.
    assert np.array_equal(cleaned.x_array, expected_x)
    assert np.array_equal(cleaned.y_array, expected_y)

    # Everything except the data itself must carry over unchanged.
    preserved_attributes = (
        "__inputs__",
        "__outputs__",
        "__interpolation__",
        "__extrapolation__",
        "title",
    )
    for attribute in preserved_attributes:
        assert getattr(cleaned, attribute) == getattr(original, attribute)

0 comments on commit 25374fa

Please sign in to comment.