diff --git a/tableone/tableone.py b/tableone/tableone.py index 7472137..2bf43fc 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -329,10 +329,10 @@ def setup_validators(self): self.input_validator = InputValidator() def validate_data(self, data): - self.data_validator.validate(data, self._columns) # type: ignore self.input_validator.validate(self._groupby, self._nonnormal, self._min_max, # type: ignore self._pval_adjust, self._order, self._pval, # type: ignore self._columns, self._categorical, self._continuous) # type: ignore + self.data_validator.validate(data, self._columns, self._categorical) # type: ignore def create_intermediate_tables(self, data): """ diff --git a/tableone/validators.py b/tableone/validators.py index a70dc82..bc69fa8 100644 --- a/tableone/validators.py +++ b/tableone/validators.py @@ -10,7 +10,8 @@ def __init__(self): """Initialize the DataValidator class.""" pass - def validate(self, data: pd.DataFrame, columns: list) -> None: + def validate(self, data: pd.DataFrame, columns: list, + categorical: Optional[List[str]] = None) -> None: """ Check the input dataset for obvious issues. @@ -22,6 +23,22 @@ def validate(self, data: pd.DataFrame, columns: list) -> None: self.check_unique_index(data) self.check_columns_exist(data, columns) self.check_duplicate_columns(data, columns) + if categorical: + self.check_categorical_none(data, categorical) + + def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]): + """ + Ensure that categorical columns do not contain None values. + + Parameters: + data (pd.DataFrame): The DataFrame to check. + categorical (List[str]): The list of categorical columns to validate. + """ + none_containing_cols = [col for col in categorical if data[col].isnull().any()] + if none_containing_cols: + raise InputError(f"The following categorical columns contains one or more 'None' values. These values " + f"must be converted to a string before processing: {none_containing_cols}. e.g. use " + f"data[categorical_columns] = data[categorical_columns].fillna('None')") def validate_input(self, data: pd.DataFrame): if not isinstance(data, pd.DataFrame):