Skip to content

Commit

Permalink
Merge pull request #175 from tompollard/tp/auto_fill_nulls
Browse files Browse the repository at this point in the history
Add include_null argument to handle nulls for categorical values. Ref #114.
  • Loading branch information
tompollard authored Jun 14, 2024
2 parents bd27742 + 02a7326 commit 6fc6a30
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 23 deletions.
16 changes: 16 additions & 0 deletions tableone/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd

from tableone.exceptions import InputError

Expand Down Expand Up @@ -99,3 +100,18 @@ def get_groups(data, groupby, order, reserved_columns):
groupbylvls = ['Overall']

return groupbylvls


def handle_categorical_nulls(df: pd.DataFrame, null_value: str = 'None') -> pd.DataFrame:
    """
    Convert None/NaN values in the given categorical columns to a string,
    so they are treated as an additional category level.

    Parameters:
    - df (pd.DataFrame): DataFrame containing the categorical columns.
    - null_value (str): The string to replace null values with. Default is 'None'.

    Returns:
    - pd.DataFrame: A new DataFrame with nulls replaced; the input is not mutated.
    """
    out = df.copy()
    for col in out.columns:
        # For pandas Categorical dtype, fillna raises if the fill value is not
        # already a declared category — register it first.
        if isinstance(out[col].dtype, pd.CategoricalDtype):
            if null_value not in out[col].cat.categories:
                out[col] = out[col].cat.add_categories([null_value])
        out[col] = out[col].fillna(null_value)
    return out
36 changes: 27 additions & 9 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from tabulate import tabulate

from tableone.deprecations import handle_deprecated_parameters
from tableone.preprocessors import ensure_list, detect_categorical, order_categorical, get_groups
from tableone.preprocessors import (ensure_list, detect_categorical, order_categorical,
get_groups, handle_categorical_nulls)
from tableone.statistics import Statistics
from tableone.tables import Tables
from tableone.validators import DataValidator, InputValidator
Expand Down Expand Up @@ -168,6 +169,10 @@ class TableOne:
Run Tukey's test for far outliers. If variables are found to
have far outliers, a remark will be added below the Table 1.
(default: False)
include_null : bool, optional
Include None/Null values for categorical variables by treating them as a
category level. (default: True)
Attributes
----------
Expand Down Expand Up @@ -219,7 +224,8 @@ def __init__(self, data: pd.DataFrame,
row_percent: bool = False, display_all: bool = False,
dip_test: bool = False, normal_test: bool = False,
tukey_test: bool = False,
pval_threshold: Optional[float] = None) -> None:
pval_threshold: Optional[float] = None,
include_null: Optional[bool] = True) -> None:

# Warn about deprecated parameters
handle_deprecated_parameters(labels, isnull, pval_test_name, remarks)
Expand All @@ -229,11 +235,12 @@ def __init__(self, data: pd.DataFrame,
self.tables = Tables()

# Initialize attributes
self.initialize_core_attributes(data, columns, categorical, continuous, groupby,
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold)
data = self.initialize_core_attributes(data, columns, categorical, continuous, groupby,
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold,
include_null)

# Initialize intermediate tables
self.initialize_intermediate_tables()
Expand Down Expand Up @@ -274,11 +281,13 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro
nonnormal, min_max, pval, pval_adjust, htest_name,
htest, missing, ddof, rename, sort, limit, order,
label_suffix, decimals, smd, overall, row_percent,
dip_test, normal_test, tukey_test, pval_threshold):
dip_test, normal_test, tukey_test, pval_threshold,
include_null):
"""
Initialize attributes.
"""
self._alt_labels = rename
self._include_null = include_null
self._columns = columns if columns else data.columns.to_list() # type: ignore
self._categorical = detect_categorical(data[self._columns], groupby) if categorical is None else categorical
if continuous:
Expand Down Expand Up @@ -308,8 +317,14 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro
self._sort = sort
self._tukey_test = tukey_test
self._warnings = {}

if self._categorical and self._include_null:
data[self._categorical] = handle_categorical_nulls(data[self._categorical])

self._groupbylvls = get_groups(data, self._groupby, self._order, self._reserved_columns)

return data

def initialize_intermediate_tables(self):
"""
Initialize the intermediate tables.
Expand All @@ -329,10 +344,10 @@ def setup_validators(self):
self.input_validator = InputValidator()

def validate_data(self, data):
self.data_validator.validate(data, self._columns) # type: ignore
self.input_validator.validate(self._groupby, self._nonnormal, self._min_max, # type: ignore
self._pval_adjust, self._order, self._pval, # type: ignore
self._columns, self._categorical, self._continuous) # type: ignore
self.data_validator.validate(data, self._columns, self._categorical, self._include_null) # type: ignore

def create_intermediate_tables(self, data):
"""
Expand All @@ -351,6 +366,7 @@ def create_intermediate_tables(self, data):
self._categorical,
self._decimals,
self._row_percent,
self._include_null,
groupby=None,
groupbylvls=['Overall'])

Expand All @@ -370,6 +386,7 @@ def create_intermediate_tables(self, data):
self._categorical,
self._decimals,
self._row_percent,
self._include_null,
groupby=self._groupby,
groupbylvls=self._groupbylvls)

Expand Down Expand Up @@ -398,6 +415,7 @@ def create_intermediate_tables(self, data):
self._overall,
self.cat_describe,
self._categorical,
self._include_null,
self._pval,
self._pval_adjust,
self.htest_table,
Expand Down
26 changes: 20 additions & 6 deletions tableone/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def create_cat_describe(self,
categorical,
decimals,
row_percent,
include_null,
groupby: Optional[str] = None,
groupbylvls: Optional[list] = None
) -> pd.DataFrame:
Expand Down Expand Up @@ -223,12 +224,19 @@ def create_cat_describe(self,
else:
df = cat_slice.copy()

# create n column and null count column
# create n column
# must be done before converting values to strings
ct = df.count().to_frame(name='n')
ct.index.name = 'variable'
nulls = df.isnull().sum().to_frame(name='Missing')
nulls.index.name = 'variable'

if include_null:
# create an empty Missing column for display purposes
nulls = pd.DataFrame('', index=df.columns, columns=['Missing'])
nulls.index.name = 'variable'
else:
# Count and display null count
nulls = df.isnull().sum().to_frame(name='Missing')
nulls.index.name = 'variable'

# Convert to str to handle int converted to boolean in the index.
# Also avoid nans.
Expand Down Expand Up @@ -445,6 +453,7 @@ def create_cat_table(self,
overall,
cat_describe,
categorical,
include_null,
pval,
pval_adjust,
htest_table,
Expand All @@ -462,9 +471,14 @@ def create_cat_table(self,
"""
table = cat_describe['t1_summary'].copy()

# add the total count of null values across all levels
isnull = data[categorical].isnull().sum().to_frame(name='Missing')
isnull.index = isnull.index.rename('variable')
if include_null:
isnull = pd.DataFrame(index=categorical, columns=['Missing'])
isnull['Missing'] = ''
isnull.index.rename('variable', inplace=True)
else:
# add the total count of null values across all levels
isnull = data[categorical].isnull().sum().to_frame(name='Missing')
isnull.index = isnull.index.rename('variable')

try:
table = table.join(isnull)
Expand Down
4 changes: 3 additions & 1 deletion tableone/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def __init__(self):
"""Initialize the DataValidator class."""
pass

def validate(self, data: pd.DataFrame, columns: list) -> None:
def validate(self, data: pd.DataFrame, columns: list,
categorical: list,
include_null: bool) -> None:
"""
Check the input dataset for obvious issues.
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/test_tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_overall_n_and_percent_for_binary_cat_var_with_nan(
"""
categorical = ['likeshoney']
table = TableOne(data_sample, columns=categorical,
categorical=categorical)
categorical=categorical, include_null=False)

lh = table.cat_describe.loc['likeshoney']

Expand Down Expand Up @@ -796,7 +796,8 @@ def test_nan_rows_not_deleted_in_categorical_columns(self):

# create tableone
t1 = TableOne(df, label_suffix=False,
categorical=['basket1', 'basket2', 'basket3', 'basket4'])
categorical=['basket1', 'basket2', 'basket3', 'basket4'],
include_null=False)

assert all(t1.tableone.loc['basket1'].index == ['apple', 'banana',
'durian', 'lemon',
Expand Down Expand Up @@ -1028,7 +1029,7 @@ def test_order_of_order_categorical_columns(self):

# if a custom order is not specified, the categorical order
# specified above should apply
t1 = TableOne(data, label_suffix=False)
t1 = TableOne(data, label_suffix=False, include_null=False)

t1_expected_order = {'month': ["feb", "jan", "mar", "apr"],
'day': ["wed", "thu", "mon", "tue"]}
Expand All @@ -1039,7 +1040,7 @@ def test_order_of_order_categorical_columns(self):
t1_expected_order[k])

# if a desired order is set, it should override the order
t2 = TableOne(data, order=order, label_suffix=False)
t2 = TableOne(data, order=order, label_suffix=False, include_null=False)

t2_expected_order = {'month': ["jan", "feb", "mar", "apr"],
'day': ["mon", "tue", "wed", "thu"]}
Expand Down Expand Up @@ -1104,7 +1105,7 @@ def test_row_percent_false(self, data_pn):
t1 = TableOne(data_pn, columns=columns,
categorical=categorical, groupby=groupby,
nonnormal=nonnormal, decimals=decimals,
row_percent=False)
row_percent=False, include_null=False)

row1 = list(t1.tableone.loc["MechVent, n (%)"][group].values[0])
row1_expect = [0, '540 (54.0)', '468 (54.2)', '72 (52.9)']
Expand Down Expand Up @@ -1154,7 +1155,7 @@ def test_row_percent_true(self, data_pn):
t2 = TableOne(data_pn, columns=columns,
categorical=categorical, groupby=groupby,
nonnormal=nonnormal, decimals=decimals,
row_percent=True)
row_percent=True, include_null=False)

row1 = list(t2.tableone.loc["MechVent, n (%)"][group].values[0])
row1_expect = [0, '540 (100.0)', '468 (86.7)', '72 (13.3)']
Expand Down Expand Up @@ -1204,7 +1205,7 @@ def test_row_percent_true_and_overall_false(self, data_pn):
t1 = TableOne(data_pn, columns=columns, overall=False,
categorical=categorical, groupby=groupby,
nonnormal=nonnormal, decimals=decimals,
row_percent=True)
row_percent=True, include_null=False)

row1 = list(t1.tableone.loc["MechVent, n (%)"][group].values[0])
row1_expect = [0, '468 (86.7)', '72 (13.3)']
Expand Down

0 comments on commit 6fc6a30

Please sign in to comment.