Skip to content

Commit

Permalink
fix(SMOTETomek): improve documentation and parameter typing
Browse files (browse the repository at this point in the history)
  • Loading branch information
msorvoja committed Nov 25, 2024
1 parent 59bbda1 commit bb2ff19
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
19 changes: 11 additions & 8 deletions eis_toolkit/training_data_tools/class_balancing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
from beartype import beartype
from beartype.typing import Optional, Union
from beartype.typing import Literal, Optional, Union
from imblearn.combine import SMOTETomek

from eis_toolkit.exceptions import NonMatchingParameterLengthsException
Expand All @@ -11,24 +11,27 @@
def balance_SMOTETomek(
X: Union[pd.DataFrame, np.ndarray],
y: Union[pd.Series, np.ndarray],
sampling_strategy: Union[float, str, dict] = "auto",
sampling_strategy: Union[float, Literal["minority", "not minority", "not majority", "all", "auto"], dict] = "auto",
random_state: Optional[int] = None,
) -> tuple[Union[pd.DataFrame, np.ndarray], Union[pd.Series, np.ndarray]]:
"""Balances the classes of input dataset using SMOTETomek resampling method.
"""
Balances the classes of input dataset using SMOTETomek resampling method.
For more information about Imblearn SMOTETomek read the documentation here:
https://imbalanced-learn.org/stable/references/generated/imblearn.combine.SMOTETomek.html.
Args:
X: The feature matrix (input data as a DataFrame).
y: The target labels corresponding to the feature matrix.
X: Input feature data to be sampled.
y: Target labels corresponding to the input features.
sampling_strategy: Parameter controlling how to perform the resampling.
If float, specifies the ratio of samples in minority class to samples of majority class,
if str, specifies classes to be resampled ("minority", "not minority", "not majority", "all", "auto"),
if dict, the keys should be targeted classes and values the desired number of samples for the class.
Defaults to "auto", which will resample all classes except the majority class.
random_state: Parameter controlling randomization of the algorithm. Can be given a seed (number).
Defaults to None, which randomizes the seed.
random_state: Seed for random number generation. Defaults to None.
Returns:
Resampled feature matrix and target labels.
Resampled feature data and target labels.
Raises:
NonMatchingParameterLengthsException: If X and y have different lengths.
Expand Down
5 changes: 3 additions & 2 deletions tests/training_data_tools/class_balancing_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from beartype.roar import BeartypeCallHintParamViolation
from sklearn.datasets import make_classification

from eis_toolkit.exceptions import NonMatchingParameterLengthsException
Expand Down Expand Up @@ -37,6 +38,6 @@ def test_invalid_label_length():


def test_invalid_sampling_strategy():
"""Test that invalid value for sampling strategy raises the correct exception (generated by imblearn)."""
with pytest.raises(ValueError):
"""Test that invalid value for sampling strategy raises the correct exception."""
with pytest.raises(BeartypeCallHintParamViolation):
balance_SMOTETomek(X, y, sampling_strategy="invalid_strategy")

0 comments on commit bb2ff19

Please sign in to comment.