From 44ef61fa6321b5680c9d33a2a269748c66eebe17 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Thu, 13 Jun 2024 02:41:19 -0400 Subject: [PATCH] Add auto_fill_nulls argument. Ref #114 Missing values are now treated as a category for categorical values. --- tableone/preprocessors.py | 16 ++++++++++++++++ tableone/tableone.py | 33 ++++++++++++++++++++++++--------- tableone/validators.py | 14 ++++++++------ 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/tableone/preprocessors.py b/tableone/preprocessors.py index caf7646..ed8b18f 100644 --- a/tableone/preprocessors.py +++ b/tableone/preprocessors.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd from tableone.exceptions import InputError @@ -99,3 +100,18 @@ def get_groups(data, groupby, order, reserved_columns): groupbylvls = ['Overall'] return groupbylvls + + +def handle_categorical_nulls(df: pd.DataFrame, null_value: str = 'None') -> pd.DataFrame: + """ + Convert None/Null values in specified categorical columns to a given string, + so they are treated as an additional category. + + Parameters: + - data (pd.DataFrame): The DataFrame containing the categorical data. + - null_value (str): The string to replace null values with. Default is 'None'. + + Returns: + - pd.DataFrame: The modified DataFrame if not inplace, otherwise None. + """ + return df.fillna(null_value) diff --git a/tableone/tableone.py b/tableone/tableone.py index 2bf43fc..948d49a 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -10,7 +10,8 @@ from tabulate import tabulate from tableone.deprecations import handle_deprecated_parameters -from tableone.preprocessors import ensure_list, detect_categorical, order_categorical, get_groups +from tableone.preprocessors import (ensure_list, detect_categorical, order_categorical, + get_groups, handle_categorical_nulls) from tableone.statistics import Statistics from tableone.tables import Tables from tableone.validators import DataValidator, InputValidator @@ -168,6 +169,10 @@ class TableOne: Run Tukey's test for far outliers. If variables are found to have far outliers, a remark will be added below the Table 1. (default: False) + auto_fill_nulls : bool, optional + Attempt to automatically handle None/Null values in categorical columns + by treating them as a category named 'None'. (default: True) + Attributes ---------- @@ -219,7 +224,8 @@ def __init__(self, data: pd.DataFrame, row_percent: bool = False, display_all: bool = False, dip_test: bool = False, normal_test: bool = False, tukey_test: bool = False, - pval_threshold: Optional[float] = None) -> None: + pval_threshold: Optional[float] = None, + auto_fill_nulls: Optional[bool] = True) -> None: # Warn about deprecated parameters handle_deprecated_parameters(labels, isnull, pval_test_name, remarks) @@ -229,11 +235,12 @@ def __init__(self, data: pd.DataFrame, self.tables = Tables() # Initialize attributes - self.initialize_core_attributes(data, columns, categorical, continuous, groupby, - nonnormal, min_max, pval, pval_adjust, htest_name, - htest, missing, ddof, rename, sort, limit, order, - label_suffix, decimals, smd, overall, row_percent, - dip_test, normal_test, tukey_test, pval_threshold) + data = self.initialize_core_attributes(data, columns, categorical, continuous, groupby, + nonnormal, min_max, pval, pval_adjust, htest_name, + htest, missing, ddof, rename, sort, limit, order, + label_suffix, decimals, smd, overall, row_percent, + dip_test, normal_test, tukey_test, pval_threshold, + auto_fill_nulls) # Initialize intermediate tables self.initialize_intermediate_tables() @@ -274,11 +281,13 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro nonnormal, min_max, pval, pval_adjust, htest_name, htest, missing, ddof, rename, sort, limit, order, label_suffix, decimals, smd, overall, row_percent, - dip_test, normal_test, tukey_test, pval_threshold): + dip_test, normal_test, tukey_test, pval_threshold, + auto_fill_nulls): """ Initialize attributes. """ self._alt_labels = rename + self._auto_fill_nulls = auto_fill_nulls self._columns = columns if columns else data.columns.to_list() # type: ignore self._categorical = detect_categorical(data[self._columns], groupby) if categorical is None else categorical if continuous: @@ -308,8 +317,14 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro self._sort = sort self._tukey_test = tukey_test self._warnings = {} + + if self._categorical and self._auto_fill_nulls: + data[self._categorical] = handle_categorical_nulls(data[self._categorical]) + self._groupbylvls = get_groups(data, self._groupby, self._order, self._reserved_columns) + return data + def initialize_intermediate_tables(self): """ Initialize the intermediate tables. @@ -332,7 +347,7 @@ def validate_data(self, data): self.input_validator.validate(self._groupby, self._nonnormal, self._min_max, # type: ignore self._pval_adjust, self._order, self._pval, # type: ignore self._columns, self._categorical, self._continuous) # type: ignore - self.data_validator.validate(data, self._columns, self._categorical) # type: ignore + self.data_validator.validate(data, self._columns, self._categorical, self._auto_fill_nulls) # type: ignore def create_intermediate_tables(self, data): """ diff --git a/tableone/validators.py b/tableone/validators.py index bc69fa8..d738293 100644 --- a/tableone/validators.py +++ b/tableone/validators.py @@ -11,7 +11,8 @@ def __init__(self): pass def validate(self, data: pd.DataFrame, columns: list, - categorical: Optional[List[str]] = None) -> None: + categorical: list, + auto_fill_nulls: bool) -> None: """ Check the input dataset for obvious issues. @@ -23,7 +24,7 @@ def validate(self, data: pd.DataFrame, columns: list, self.check_unique_index(data) self.check_columns_exist(data, columns) self.check_duplicate_columns(data, columns) - if categorical: + if categorical and not auto_fill_nulls: self.check_categorical_none(data, categorical) def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]): @@ -34,10 +35,11 @@ def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]): data (pd.DataFrame): The DataFrame to check. categorical (List[str]): The list of categorical columns to validate. """ - none_containing_cols = [col for col in categorical if data[col].isnull().any()] - if none_containing_cols: - raise InputError(f"The following categorical columns contains one or more 'None' values. These values " - f"must be converted to a string before processing: {none_containing_cols}. e.g. use " + contains_none = [col for col in categorical if data[col].isnull().any()] + if contains_none: + raise InputError(f"The following categorical columns contains one or more null values: {contains_none}. " + f"These must be converted to strings before processing. Either set " + f"`auto_fill_nulls = True` or manually convert nulls to strings with: " f"data[categorical_columns] = data[categorical_columns].fillna('None')") def validate_input(self, data: pd.DataFrame):