diff --git a/CHANGELOG.md b/CHANGELOG.md
index 863077840..7ff32bd8d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 
 ### Added
 
+- ENH: adds `Function.remove_outliers_iqr` method [#554](https://github.com/RocketPy-Team/RocketPy/pull/554)
 ### Changed
 
 - ENH: Optional argument to show the plot in Function.compare_plots [#563](https://github.com/RocketPy-Team/RocketPy/pull/563)
diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py
index cd87d7598..cefed044d 100644
--- a/rocketpy/mathutils/function.py
+++ b/rocketpy/mathutils/function.py
@@ -1144,6 +1144,63 @@ def low_pass_filter(self, alpha, file_path=None):
             title=self.title,
         )
 
+    def remove_outliers_iqr(self, threshold=1.5):
+        """Remove outliers from the Function source using the interquartile
+        range (IQR) method, also known as Tukey's fences. The Function should
+        have an array-like source.
+
+        A sample is kept when its y value lies in the closed interval
+        ``[Q1 - threshold * IQR, Q3 + threshold * IQR]``, where ``Q1`` and
+        ``Q3`` are the 25th and 75th percentiles of ``self.y_array`` and
+        ``IQR = Q3 - Q1``. The original Function is not modified.
+
+        Parameters
+        ----------
+        threshold : float, optional
+            Multiplier applied to the interquartile range to set the fences.
+            Default is 1.5.
+
+        Returns
+        -------
+        Function
+            A new Function built from the remaining (x, y) samples, with the
+            same inputs, outputs, interpolation, extrapolation and title.
+
+        Raises
+        ------
+        TypeError
+            If the Function source is a callable object.
+
+        References
+        ----------
+        [1] https://en.wikipedia.org/wiki/Outlier#Tukey's_fences
+        """
+
+        if callable(self.source):
+            raise TypeError(
+                "Cannot remove outliers if the source is a callable object."
+                + " The Function.source should be array-like."
+            )
+
+        x = self.x_array
+        y = self.y_array
+        y_q1, y_q3 = np.percentile(y, [25, 75])
+        y_iqr = y_q3 - y_q1
+        y_lower = y_q1 - threshold * y_iqr
+        y_upper = y_q3 + threshold * y_iqr
+
+        # Build the inlier mask once so x and y are filtered consistently.
+        inliers = (y >= y_lower) & (y <= y_upper)
+
+        return Function(
+            source=np.column_stack((x[inliers], y[inliers])),
+            inputs=self.__inputs__,
+            outputs=self.__outputs__,
+            interpolation=self.__interpolation__,
+            extrapolation=self.__extrapolation__,
+            title=self.title,
+        )
+
     # Define all presentation methods
     def __call__(self, *args):
         """Plot the Function if no argument is given. If an
diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py
index 1455212d9..8bcefb818 100644
--- a/tests/unit/test_function.py
+++ b/tests/unit/test_function.py
@@ -276,3 +276,33 @@ def test_set_discrete_based_on_model_non_mutator(linear_func):
     assert isinstance(func, Function)
     assert discretized_func.source.shape == (4, 2)
     assert callable(func.source)
+
+
+@pytest.mark.parametrize(
+    "x, y, expected_x, expected_y",
+    [
+        (
+            np.array([1, 2, 3, 4, 5, 6]),
+            np.array([10, 20, 30, 40, 50000, 60]),
+            np.array([1, 2, 3, 4, 6]),
+            np.array([10, 20, 30, 40, 60]),
+        ),
+    ],
+)
+def test_remove_outliers_iqr(x, y, expected_x, expected_y):
+    """Test the function remove_outliers_iqr which is expected to remove
+    outliers from the data based on the Interquartile Range (IQR) method.
+    """
+    func = Function(source=np.column_stack((x, y)))
+    filtered_func = func.remove_outliers_iqr(threshold=1.5)
+
+    # Check if the outliers are removed
+    assert np.array_equal(filtered_func.x_array, expected_x)
+    assert np.array_equal(filtered_func.y_array, expected_y)
+
+    # Check if the other attributes are preserved
+    assert filtered_func.__inputs__ == func.__inputs__
+    assert filtered_func.__outputs__ == func.__outputs__
+    assert filtered_func.__interpolation__ == func.__interpolation__
+    assert filtered_func.__extrapolation__ == func.__extrapolation__
+    assert filtered_func.title == func.title