from contextlib import contextmanager

import pandas as pd


def validate_imputation(df: pd.DataFrame, target: str) -> None:
    """
    Validation for the imputation, including:
        - no missing values in target column

    Parameters
    ----------
    df : pd.DataFrame
        data with imputed values
    target : str
        name of column containing target variable

    Raises
    ------
    ValueError
        If at least one value in the target column is missing.
    KeyError
        Propagated from pandas when `target` is not a column of `df`.
    """
    if df[target].isna().any():
        # Adjacent literals rather than an indented triple-quoted f-string:
        # the latter leaks the leading newline and source indentation into
        # the message shown to the user.
        raise ValueError(
            "Target column should have no missing values following "
            f"imputation: missing values found in column {target}"
        )


# when updating to python>=3.7 this can be replaced by importing
# contextlib.nullcontext as does_not_raise
@contextmanager
def does_not_raise():
    """No-op context manager for parametrized cases expecting no exception."""
    yield
@pytest.fixture(scope="class")
def filepath():
    """Return the directory holding the validate_imputation test data."""
    return Path("tests/data/validate_imputation")


@pytest.fixture(scope="class")
def missing_target_values_data(filepath):
    """Load the CSV of target columns with differing amounts of missing data."""
    csv_path = filepath / "target_missing_values.csv"
    # index_col=False stops pandas from using the first column as the index,
    # which it would otherwise do when data rows end with a trailing delimiter.
    return pd.read_csv(csv_path, index_col=False)


class TestValidateImputation:
    """Tests for the checks performed by validate_imputation."""

    @pytest.mark.parametrize(
        "target_column_name,expectation",
        [
            ("no_missing", does_not_raise()),
            ("one_missing", pytest.raises(ValueError)),
            ("all_missing", pytest.raises(ValueError)),
        ],
    )
    def test_target_missing_values_validation(
        self,
        missing_target_values_data,
        target_column_name,
        expectation,
    ):
        """Check a ValueError is raised only when the target has missing values."""
        with expectation:
            validate_imputation(
                df=missing_target_values_data,
                target=target_column_name,
            )