From 4278de353f1bd659a51fde0d5a538bf711780a5a Mon Sep 17 00:00:00 2001 From: Gui-FernandesBR Date: Mon, 12 Feb 2024 12:48:22 -0500 Subject: [PATCH 1/3] ENH: adds `Function.remove_outliers` method --- CHANGELOG.md | 1 + rocketpy/mathutils/function.py | 72 ++++++++++++++++++++++++++++++++++ tests/unit/test_function.py | 30 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 863077840..7ff32bd8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- ENH: adds `Function.remove_outliers` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554) ### Changed - ENH: Optional argument to show the plot in Function.compare_plots [#563](https://github.com/RocketPy-Team/RocketPy/pull/563) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index cd87d7598..696ccbcb0 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1144,6 +1144,78 @@ def low_pass_filter(self, alpha, file_path=None): title=self.title, ) + def remove_outliers(self, method="iqr", **kwargs): + """Remove outliers from the Function source using the specified method. + + Parameters + ---------- + method : string, optional + Method to be used to remove outliers. Options are 'iqr'. + Default is 'iqr'. + **kwargs : optional + Keyword arguments to be passed to specific methods according to the + selected method. + If the selected method is the 'iqr', then the following kwargs are + available: + - threshold : float + Threshold for the interquartile range method. Default is 1.5. + + Returns + ------- + Function + The new Function object without outliers. + """ + if callable(self.source): + print("Cannot remove outliers if the source is a callable object.") + return self + + if method.lower() == "iqr": + return self.__remove_outliers_iqr(**kwargs) + else: + print( + f"Method '{method}' not recognized. No outliers removed." + + "Please use one of the following supported methods: 'iqr'." + ) + return self + + def __remove_outliers_iqr(self, threshold=1.5): + """Remove outliers from the Function source using the interquartile + range method. + + Parameters + ---------- + threshold : float, optional + Threshold for the interquartile range method. Default is 1.5. + + Returns + ------- + Function + The Function with the outliers removed. + + References + ---------- + [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences + """ + x = self.x_array + y = self.y_array + y_q1 = np.percentile(y, 25) + y_q3 = np.percentile(y, 75) + y_iqr = y_q3 - y_q1 + y_lower = y_q1 - threshold * y_iqr + y_upper = y_q3 + threshold * y_iqr + + y_filtered = y[(y >= y_lower) & (y <= y_upper)] + x_filtered = x[(y >= y_lower) & (y <= y_upper)] + + return Function( + source=np.column_stack((x_filtered, y_filtered)), + inputs=self.__inputs__, + outputs=self.__outputs__, + interpolation=self.__interpolation__, + extrapolation=self.__extrapolation__, + title=self.title, + ) + # Define all presentation methods def __call__(self, *args): """Plot the Function if no argument is given. If an diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 1455212d9..4439c9162 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func): assert isinstance(func, Function) assert discretized_func.source.shape == (4, 2) assert callable(func.source) + + +@pytest.mark.parametrize( + "x, y, expected_x, expected_y", + [ + ( + np.array([1, 2, 3, 4, 5, 6]), + np.array([10, 20, 30, 40, 50000, 60]), + np.array([1, 2, 3, 4, 6]), + np.array([10, 20, 30, 40, 60]), + ), + ], +) +def test_remove_outliers_iqr(x, y, expected_x, expected_y): + """Test the function __remove_outliers_iqr which is expected to remove + outliers from the data based on the Interquartile Range (IQR) method. + """ + func = Function(source=np.column_stack((x, y))) + filtered_func = func.remove_outliers(method="iqr", threshold=1.5) + + # Check if the outliers are removed + assert np.array_equal(filtered_func.x_array, expected_x) + assert np.array_equal(filtered_func.y_array, expected_y) + + # Check if the other attributes are preserved + assert filtered_func.__inputs__ == func.__inputs__ + assert filtered_func.__outputs__ == func.__outputs__ + assert filtered_func.__interpolation__ == func.__interpolation__ + assert filtered_func.__extrapolation__ == func.__extrapolation__ + assert filtered_func.title == func.title From f128c5ee380a7df3a35d8509f657f2e4632c64ce Mon Sep 17 00:00:00 2001 From: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:52:26 -0500 Subject: [PATCH 2/3] MNT: refactor the remove outliers method --- rocketpy/mathutils/function.py | 41 +++++----------------------------- tests/unit/test_function.py | 4 ++-- 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 696ccbcb0..1f6d4d3f6 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1144,41 +1144,7 @@ def low_pass_filter(self, alpha, file_path=None): title=self.title, ) - def remove_outliers(self, method="iqr", **kwargs): - """Remove outliers from the Function source using the specified method. - - Parameters - ---------- - method : string, optional - Method to be used to remove outliers. Options are 'iqr'. - Default is 'iqr'. - **kwargs : optional - Keyword arguments to be passed to specific methods according to the - selected method. - If the selected method is the 'iqr', then the following kwargs are - available: - - threshold : float - Threshold for the interquartile range method. Default is 1.5. - - Returns - ------- - Function - The new Function object without outliers. - """ - if callable(self.source): - print("Cannot remove outliers if the source is a callable object.") - return self - - if method.lower() == "iqr": - return self.__remove_outliers_iqr(**kwargs) - else: - print( - f"Method '{method}' not recognized. No outliers removed." - + "Please use one of the following supported methods: 'iqr'." - ) - return self - - def __remove_outliers_iqr(self, threshold=1.5): + def remove_outliers_iqr(self, threshold=1.5): """Remove outliers from the Function source using the interquartile range method. @@ -1196,6 +1162,11 @@ def __remove_outliers_iqr(self, threshold=1.5): ---------- [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences """ + + if callable(self.source): + print("Cannot remove outliers if the source is a callable object.") + return self + x = self.x_array y = self.y_array y_q1 = np.percentile(y, 25) diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 4439c9162..8bcefb818 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -290,11 +290,11 @@ def test_set_discrete_based_on_model_non_mutator(linear_func): ], ) def test_remove_outliers_iqr(x, y, expected_x, expected_y): - """Test the function __remove_outliers_iqr which is expected to remove + """Test the function remove_outliers_iqr which is expected to remove outliers from the data based on the Interquartile Range (IQR) method. """ func = Function(source=np.column_stack((x, y))) - filtered_func = func.remove_outliers(method="iqr", threshold=1.5) + filtered_func = func.remove_outliers_iqr(threshold=1.5) # Check if the outliers are removed assert np.array_equal(filtered_func.x_array, expected_x) From fda6b339611aff741eee79b36174fdff98197088 Mon Sep 17 00:00:00 2001 From: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> Date: Tue, 27 Feb 2024 18:11:09 +0000 Subject: [PATCH 3/3] MNT: update the error handling in the remove_outliers_iqr --- rocketpy/mathutils/function.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 1f6d4d3f6..cefed044d 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1146,7 +1146,7 @@ def low_pass_filter(self, alpha, file_path=None): def remove_outliers_iqr(self, threshold=1.5): """Remove outliers from the Function source using the interquartile - range method. + range method. The Function should have an array-like source. Parameters ---------- @@ -1164,8 +1164,10 @@ def remove_outliers_iqr(self, threshold=1.5): """ if callable(self.source): - print("Cannot remove outliers if the source is a callable object.") - return self + raise TypeError( + "Cannot remove outliers if the source is a callable object." + + " The Function.source should be array-like." + ) x = self.x_array y = self.y_array