Skip to content

Commit

Permalink
Add auto_fill_nulls argument. Ref #114
Browse files Browse the repository at this point in the history
Missing values in categorical variables are now treated as a category of their own.
  • Loading branch information
tompollard committed Jun 13, 2024
1 parent e7b7683 commit 44ef61f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
16 changes: 16 additions & 0 deletions tableone/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd

from tableone.exceptions import InputError

Expand Down Expand Up @@ -99,3 +100,18 @@ def get_groups(data, groupby, order, reserved_columns):
groupbylvls = ['Overall']

return groupbylvls


def handle_categorical_nulls(df: pd.DataFrame, null_value: str = 'None') -> pd.DataFrame:
"""
Convert None/Null values in specified categorical columns to a given string,
so they are treated as an additional category.
Parameters:
- data (pd.DataFrame): The DataFrame containing the categorical data.
- null_value (str): The string to replace null values with. Default is 'None'.
Returns:
- pd.DataFrame: The modified DataFrame if not inplace, otherwise None.
"""
return df.fillna(null_value)
33 changes: 24 additions & 9 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from tabulate import tabulate

from tableone.deprecations import handle_deprecated_parameters
from tableone.preprocessors import ensure_list, detect_categorical, order_categorical, get_groups
from tableone.preprocessors import (ensure_list, detect_categorical, order_categorical,
get_groups, handle_categorical_nulls)
from tableone.statistics import Statistics
from tableone.tables import Tables
from tableone.validators import DataValidator, InputValidator
Expand Down Expand Up @@ -168,6 +169,10 @@ class TableOne:
Run Tukey's test for far outliers. If variables are found to
have far outliers, a remark will be added below the Table 1.
(default: False)
auto_fill_nulls : bool, optional
Attempt to automatically handle None/Null values in categorical columns
by treating them as a category named 'None'. (default: True)
Attributes
----------
Expand Down Expand Up @@ -219,7 +224,8 @@ def __init__(self, data: pd.DataFrame,
row_percent: bool = False, display_all: bool = False,
dip_test: bool = False, normal_test: bool = False,
tukey_test: bool = False,
pval_threshold: Optional[float] = None) -> None:
pval_threshold: Optional[float] = None,
auto_fill_nulls: Optional[bool] = True) -> None:

# Warn about deprecated parameters
handle_deprecated_parameters(labels, isnull, pval_test_name, remarks)
Expand All @@ -229,11 +235,12 @@ def __init__(self, data: pd.DataFrame,
self.tables = Tables()

# Initialize attributes
self.initialize_core_attributes(data, columns, categorical, continuous, groupby,
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold)
data = self.initialize_core_attributes(data, columns, categorical, continuous, groupby,
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold,
auto_fill_nulls)

# Initialize intermediate tables
self.initialize_intermediate_tables()
Expand Down Expand Up @@ -274,11 +281,13 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold):
dip_test, normal_test, tukey_test, pval_threshold,
auto_fill_nulls):
"""
Initialize attributes.
"""
self._alt_labels = rename
self._auto_fill_nulls = auto_fill_nulls
self._columns = columns if columns else data.columns.to_list() # type: ignore
self._categorical = detect_categorical(data[self._columns], groupby) if categorical is None else categorical
if continuous:
Expand Down Expand Up @@ -308,8 +317,14 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro
self._sort = sort
self._tukey_test = tukey_test
self._warnings = {}

if self._categorical and self._auto_fill_nulls:
data[self._categorical] = handle_categorical_nulls(data[self._categorical])

self._groupbylvls = get_groups(data, self._groupby, self._order, self._reserved_columns)

return data

def initialize_intermediate_tables(self):
"""
Initialize the intermediate tables.
Expand All @@ -332,7 +347,7 @@ def validate_data(self, data):
self.input_validator.validate(self._groupby, self._nonnormal, self._min_max, # type: ignore
self._pval_adjust, self._order, self._pval, # type: ignore
self._columns, self._categorical, self._continuous) # type: ignore
self.data_validator.validate(data, self._columns, self._categorical) # type: ignore
self.data_validator.validate(data, self._columns, self._categorical, self._auto_fill_nulls) # type: ignore

def create_intermediate_tables(self, data):
"""
Expand Down
14 changes: 8 additions & 6 deletions tableone/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def __init__(self):
pass

def validate(self, data: pd.DataFrame, columns: list,
categorical: Optional[List[str]] = None) -> None:
categorical: list,
auto_fill_nulls: bool) -> None:
"""
Check the input dataset for obvious issues.
Expand All @@ -23,7 +24,7 @@ def validate(self, data: pd.DataFrame, columns: list,
self.check_unique_index(data)
self.check_columns_exist(data, columns)
self.check_duplicate_columns(data, columns)
if categorical:
if categorical and not auto_fill_nulls:
self.check_categorical_none(data, categorical)

def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]):
Expand All @@ -34,10 +35,11 @@ def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]):
data (pd.DataFrame): The DataFrame to check.
categorical (List[str]): The list of categorical columns to validate.
"""
none_containing_cols = [col for col in categorical if data[col].isnull().any()]
if none_containing_cols:
raise InputError(f"The following categorical columns contains one or more 'None' values. These values "
f"must be converted to a string before processing: {none_containing_cols}. e.g. use "
contains_none = [col for col in categorical if data[col].isnull().any()]
if contains_none:
raise InputError(f"The following categorical columns contains one or more null values: {contains_none}. "
f"These must be converted to strings before processing. Either set "
f"`auto_fill_nulls = True` or manually convert nulls to strings with: "
f"data[categorical_columns] = data[categorical_columns].fillna('None')")

def validate_input(self, data: pd.DataFrame):
Expand Down

0 comments on commit 44ef61f

Please sign in to comment.