From 44758a87af1efbfc13c765840cfa61e6a2022f81 Mon Sep 17 00:00:00 2001 From: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:52:26 -0500 Subject: [PATCH] MNT: refactor the remove outliers method --- rocketpy/mathutils/function.py | 41 +++++----------------------------- tests/unit/test_function.py | 4 ++-- 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 3a297a6f8..eff7b894a 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1144,41 +1144,7 @@ def low_pass_filter(self, alpha, file_path=None): title=self.title, ) - def remove_outliers(self, method="iqr", **kwargs): - """Remove outliers from the Function source using the specified method. - - Parameters - ---------- - method : string, optional - Method to be used to remove outliers. Options are 'iqr'. - Default is 'iqr'. - **kwargs : optional - Keyword arguments to be passed to specific methods according to the - selected method. - If the selected method is the 'iqr', then the following kwargs are - available: - - threshold : float - Threshold for the interquartile range method. Default is 1.5. - - Returns - ------- - Function - The new Function object without outliers. - """ - if callable(self.source): - print("Cannot remove outliers if the source is a callable object.") - return self - - if method.lower() == "iqr": - return self.__remove_outliers_iqr(**kwargs) - else: - print( - f"Method '{method}' not recognized. No outliers removed." - + "Please use one of the following supported methods: 'iqr'." - ) - return self - - def __remove_outliers_iqr(self, threshold=1.5): + def remove_outliers_iqr(self, threshold=1.5): """Remove outliers from the Function source using the interquartile range method. @@ -1196,6 +1162,11 @@ def __remove_outliers_iqr(self, threshold=1.5): ---------- [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences """ + + if callable(self.source): + print("Cannot remove outliers if the source is a callable object.") + return self + x = self.x_array y = self.y_array y_q1 = np.percentile(y, 25) diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 4439c9162..8bcefb818 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -290,11 +290,11 @@ def test_set_discrete_based_on_model_non_mutator(linear_func): ], ) def test_remove_outliers_iqr(x, y, expected_x, expected_y): - """Test the function __remove_outliers_iqr which is expected to remove + """Test the function remove_outliers_iqr which is expected to remove outliers from the data based on the Interquartile Range (IQR) method. """ func = Function(source=np.column_stack((x, y))) - filtered_func = func.remove_outliers(method="iqr", threshold=1.5) + filtered_func = func.remove_outliers_iqr(threshold=1.5) # Check if the outliers are removed assert np.array_equal(filtered_func.x_array, expected_x)