diff --git a/CHANGELOG.md b/CHANGELOG.md index a3503a9e9..dafd84e02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- ENH: adds `Function.remove_outliers` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554) ### Changed diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index e8b6a9318..3a297a6f8 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -1144,6 +1144,78 @@ def low_pass_filter(self, alpha, file_path=None): title=self.title, ) + def remove_outliers(self, method="iqr", **kwargs): + """Remove outliers from the Function source using the specified method. + + Parameters + ---------- + method : string, optional + Method to be used to remove outliers. Options are 'iqr'. + Default is 'iqr'. + **kwargs : optional + Keyword arguments to be passed to specific methods according to the + selected method. + If the selected method is the 'iqr', then the following kwargs are + available: + - threshold : float + Threshold for the interquartile range method. Default is 1.5. + + Returns + ------- + Function + The new Function object without outliers. + """ + if callable(self.source): + print("Cannot remove outliers if the source is a callable object.") + return self + + if method.lower() == "iqr": + return self.__remove_outliers_iqr(**kwargs) + else: + print( + f"Method '{method}' not recognized. No outliers removed." + + "Please use one of the following supported methods: 'iqr'." + ) + return self + + def __remove_outliers_iqr(self, threshold=1.5): + """Remove outliers from the Function source using the interquartile + range method. + + Parameters + ---------- + threshold : float, optional + Threshold for the interquartile range method. Default is 1.5. + + Returns + ------- + Function + The Function with the outliers removed. + + References + ---------- + [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences + """ + x = self.x_array + y = self.y_array + y_q1 = np.percentile(y, 25) + y_q3 = np.percentile(y, 75) + y_iqr = y_q3 - y_q1 + y_lower = y_q1 - threshold * y_iqr + y_upper = y_q3 + threshold * y_iqr + + y_filtered = y[(y >= y_lower) & (y <= y_upper)] + x_filtered = x[(y >= y_lower) & (y <= y_upper)] + + return Function( + source=np.column_stack((x_filtered, y_filtered)), + inputs=self.__inputs__, + outputs=self.__outputs__, + interpolation=self.__interpolation__, + extrapolation=self.__extrapolation__, + title=self.title, + ) + # Define all presentation methods def __call__(self, *args): """Plot the Function if no argument is given. If an diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 1455212d9..4439c9162 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func): assert isinstance(func, Function) assert discretized_func.source.shape == (4, 2) assert callable(func.source) + + +@pytest.mark.parametrize( + "x, y, expected_x, expected_y", + [ + ( + np.array([1, 2, 3, 4, 5, 6]), + np.array([10, 20, 30, 40, 50000, 60]), + np.array([1, 2, 3, 4, 6]), + np.array([10, 20, 30, 40, 60]), + ), + ], +) +def test_remove_outliers_iqr(x, y, expected_x, expected_y): + """Test the function __remove_outliers_iqr which is expected to remove + outliers from the data based on the Interquartile Range (IQR) method. + """ + func = Function(source=np.column_stack((x, y))) + filtered_func = func.remove_outliers(method="iqr", threshold=1.5) + + # Check if the outliers are removed + assert np.array_equal(filtered_func.x_array, expected_x) + assert np.array_equal(filtered_func.y_array, expected_y) + + # Check if the other attributes are preserved + assert filtered_func.__inputs__ == func.__inputs__ + assert filtered_func.__outputs__ == func.__outputs__ + assert filtered_func.__interpolation__ == func.__interpolation__ + assert filtered_func.__extrapolation__ == func.__extrapolation__ + assert filtered_func.title == func.title