From 0835de6359b28737220653708e92a61854334d17 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Fri, 7 Jun 2024 00:08:40 -0400 Subject: [PATCH] Move detect_categorical function to preprocessors module. --- tableone/preprocessors.py | 36 ++++++++++++++++++++++++++++++++- tableone/tableone.py | 42 +++++++-------------------------------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/tableone/preprocessors.py b/tableone/preprocessors.py index 6c0283f..ce96335 100644 --- a/tableone/preprocessors.py +++ b/tableone/preprocessors.py @@ -1,4 +1,4 @@ - +import numpy as np def ensure_list(arg, arg_name): """ @@ -12,3 +12,37 @@ def ensure_list(arg, arg_name): return arg else: raise TypeError(f"{arg_name} must be a string or a list of strings.") + + +def detect_categorical(data, groupby) -> list: + """ + Detect categorical columns if they are not specified. + + Parameters + ---------- + data : pandas DataFrame + The input dataset. + groupby : str (optional) + The groupby variable. + + Returns + ---------- + likely_cat : list + List of variables that appear to be categorical. 
+ """ + # assume all non-numerical columns, except date columns, are categorical + numeric_cols = set(data._get_numeric_data().columns.values) + date_cols = set(data.select_dtypes(include=[np.datetime64]).columns) + likely_cat = set(data.columns) - numeric_cols + likely_cat = list(likely_cat - date_cols) + + # check proportion of unique values if numerical + for var in data._get_numeric_data().columns: + likely_flag = 1.0 * data[var].nunique()/data[var].count() < 0.005 + if likely_flag: + likely_cat.append(var) + + if groupby: + likely_cat = [x for x in likely_cat if x != groupby] + + return likely_cat diff --git a/tableone/tableone.py b/tableone/tableone.py index 2b29bad..ac9c8f5 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -10,7 +10,7 @@ from tabulate import tabulate from tableone.deprecations import deprecated_parameter -from tableone.preprocessors import ensure_list +from tableone.preprocessors import ensure_list, detect_categorical from tableone.statistics import Statistics from tableone.validators import DataValidator, InputValidator, InputError @@ -236,13 +236,12 @@ def __init__(self, data: pd.DataFrame, self.input_validator.validate(groupby, self._nonnormal, min_max, pval_adjust, order, # type: ignore pval, self._columns, categorical, continuous) # type: ignore - # if categorical not specified, try to identify categorical + # if categorical is set to None, try to automatically detect # if empty list is provided, assume there are no categorical variables. - if not categorical and type(categorical) != list: - categorical = self._detect_categorical_columns(data[self._columns]) - # omit categorical row if it is specified in groupby - if groupby: - categorical = [x for x in categorical if x != groupby] + if categorical is None: + self._categorical = detect_categorical(data[self._columns], groupby) + else: + self._categorical = categorical # if input df has ordered categorical variables, get the order. 
order_cats = [x for x in data.select_dtypes("category") @@ -266,8 +265,7 @@ def __init__(self, data: pd.DataFrame, self._continuous = continuous else: self._continuous = [c for c in self._columns # type: ignore - if c not in categorical + [groupby]] - self._categorical = categorical + if c not in self._categorical + [groupby]] self._ddof = ddof self._decimals = decimals self._dip_test = dip_test @@ -481,32 +479,6 @@ def _generate_remarks(self, newline='\n') -> str: return msg - def _detect_categorical_columns(self, data) -> list: - """ - Detect categorical columns if they are not specified. - - Parameters - ---------- - data : pandas DataFrame - The input dataset. - - Returns - ---------- - likely_cat : list - List of variables that appear to be categorical. - """ - # assume all non-numerical and date columns are categorical - numeric_cols = set(data._get_numeric_data().columns.values) - date_cols = set(data.select_dtypes(include=[np.datetime64]).columns) - likely_cat = set(data.columns) - numeric_cols - likely_cat = list(likely_cat - date_cols) - # check proportion of unique values if numerical - for var in data._get_numeric_data().columns: - likely_flag = 1.0 * data[var].nunique()/data[var].count() < 0.005 - if likely_flag: - likely_cat.append(var) - return likely_cat - def _t1_summary(self, x: pd.Series) -> str: """ Compute median [IQR] or mean (Std) for the input series.