Skip to content

Commit

Permalink
Move detect_categorical function to preprocessors module.
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jun 7, 2024
1 parent baa2b1f commit 0835de6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 36 deletions.
36 changes: 35 additions & 1 deletion tableone/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import numpy as np

def ensure_list(arg, arg_name):
"""
Expand All @@ -12,3 +12,37 @@ def ensure_list(arg, arg_name):
return arg
else:
raise TypeError(f"{arg_name} must be a string or a list of strings.")


def detect_categorical(data, groupby) -> list:
"""
Detect categorical columns if they are not specified.
Parameters
----------
data : pandas DataFrame
The input dataset.
groupby : str (optional)
The groupby variable.
Returns
----------
likely_cat : list
List of variables that appear to be categorical.
"""
# assume all non-numerical and date columns are categorical
numeric_cols = set(data._get_numeric_data().columns.values)
date_cols = set(data.select_dtypes(include=[np.datetime64]).columns)
likely_cat = set(data.columns) - numeric_cols
likely_cat = list(likely_cat - date_cols)

# check proportion of unique values if numerical
for var in data._get_numeric_data().columns:
likely_flag = 1.0 * data[var].nunique()/data[var].count() < 0.005
if likely_flag:
likely_cat.append(var)

if groupby:
likely_cat = [x for x in likely_cat if x != groupby]

return likely_cat
42 changes: 7 additions & 35 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tabulate import tabulate

from tableone.deprecations import deprecated_parameter
from tableone.preprocessors import ensure_list
from tableone.preprocessors import ensure_list, detect_categorical
from tableone.statistics import Statistics
from tableone.validators import DataValidator, InputValidator, InputError

Expand Down Expand Up @@ -236,13 +236,12 @@ def __init__(self, data: pd.DataFrame,
self.input_validator.validate(groupby, self._nonnormal, min_max, pval_adjust, order, # type: ignore
pval, self._columns, categorical, continuous) # type: ignore

# if categorical not specified, try to identify categorical
# if categorical is set to None, try to automatically detect
# if empty list is provided, assume there are no categorical variables.
if not categorical and type(categorical) != list:
categorical = self._detect_categorical_columns(data[self._columns])
# omit categorical row if it is specified in groupby
if groupby:
categorical = [x for x in categorical if x != groupby]
if categorical is None:
self._categorical = detect_categorical(data[self._columns], groupby)
else:
self._categorical = categorical

# if input df has ordered categorical variables, get the order.
order_cats = [x for x in data.select_dtypes("category")
Expand All @@ -266,8 +265,7 @@ def __init__(self, data: pd.DataFrame,
self._continuous = continuous
else:
self._continuous = [c for c in self._columns # type: ignore
if c not in categorical + [groupby]]
self._categorical = categorical
if c not in self._categorical + [groupby]]
self._ddof = ddof
self._decimals = decimals
self._dip_test = dip_test
Expand Down Expand Up @@ -481,32 +479,6 @@ def _generate_remarks(self, newline='\n') -> str:

return msg

def _detect_categorical_columns(self, data) -> list:
"""
Detect categorical columns if they are not specified.
Parameters
----------
data : pandas DataFrame
The input dataset.
Returns
----------
likely_cat : list
List of variables that appear to be categorical.
"""
# assume all non-numerical and date columns are categorical
numeric_cols = set(data._get_numeric_data().columns.values)
date_cols = set(data.select_dtypes(include=[np.datetime64]).columns)
likely_cat = set(data.columns) - numeric_cols
likely_cat = list(likely_cat - date_cols)
# check proportion of unique values if numerical
for var in data._get_numeric_data().columns:
likely_flag = 1.0 * data[var].nunique()/data[var].count() < 0.005
if likely_flag:
likely_cat.append(var)
return likely_cat

def _t1_summary(self, x: pd.Series) -> str:
"""
Compute median [IQR] or mean (Std) for the input series.
Expand Down

0 comments on commit 0835de6

Please sign in to comment.