diff --git a/eis_toolkit/training_data_tools/class_balancing.py b/eis_toolkit/training_data_tools/class_balancing.py index 3120bf38..f8b80c10 100644 --- a/eis_toolkit/training_data_tools/class_balancing.py +++ b/eis_toolkit/training_data_tools/class_balancing.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd from beartype import beartype -from beartype.typing import Optional, Union +from beartype.typing import Literal, Optional, Union from imblearn.combine import SMOTETomek from eis_toolkit.exceptions import NonMatchingParameterLengthsException @@ -11,24 +11,27 @@ def balance_SMOTETomek( X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray], - sampling_strategy: Union[float, str, dict] = "auto", + sampling_strategy: Union[float, Literal["minority", "not minority", "not majority", "all", "auto"], dict] = "auto", random_state: Optional[int] = None, ) -> tuple[Union[pd.DataFrame, np.ndarray], Union[pd.Series, np.ndarray]]: - """Balances the classes of input dataset using SMOTETomek resampling method. + """ + Balances the classes of input dataset using SMOTETomek resampling method. + + For more information about Imblearn SMOTETomek read the documentation here: + https://imbalanced-learn.org/stable/references/generated/imblearn.combine.SMOTETomek.html. Args: - X: The feature matrix (input data as a DataFrame). - y: The target labels corresponding to the feature matrix. + X: Input feature data to be sampled. + y: Target labels corresponding to the input features. sampling_strategy: Parameter controlling how to perform the resampling. If float, specifies the ratio of samples in minority class to samples of majority class, if str, specifies classes to be resampled ("minority", "not minority", "not majority", "all", "auto"), if dict, the keys should be targeted classes and values the desired number of samples for the class. Defaults to "auto", which will resample all classes except the majority class. - random_state: Parameter controlling randomization of the algorithm. Can be given a seed (number). - Defaults to None, which randomizes the seed. + random_state: Seed for random number generation. Defaults to None. Returns: - Resampled feature matrix and target labels. + Resampled feature data and target labels. Raises: NonMatchingParameterLengthsException: If X and y have different length. diff --git a/tests/training_data_tools/class_balancing_test.py b/tests/training_data_tools/class_balancing_test.py index 6703d433..b2ce3d49 100644 --- a/tests/training_data_tools/class_balancing_test.py +++ b/tests/training_data_tools/class_balancing_test.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from beartype.roar import BeartypeCallHintParamViolation from sklearn.datasets import make_classification from eis_toolkit.exceptions import NonMatchingParameterLengthsException @@ -37,6 +38,6 @@ def test_invalid_label_length(): def test_invalid_sampling_strategy(): - """Test that invalid value for sampling strategy raises the correct exception (generated by imblearn).""" - with pytest.raises(ValueError): + """Test that invalid value for sampling strategy raises the correct exception.""" + with pytest.raises(BeartypeCallHintParamViolation): balance_SMOTETomek(X, y, sampling_strategy="invalid_strategy")