Skip to content

Commit

Permalink
Add module for preprocessing input args and data.
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jun 7, 2024
1 parent 9d21c8e commit baa2b1f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
14 changes: 14 additions & 0 deletions tableone/preprocessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@


def ensure_list(arg, arg_name):
"""
Ensure input argument is a list.
"""
if arg is None:
return []
elif isinstance(arg, str):
return [arg]
elif isinstance(arg, list):
return arg
else:
raise TypeError(f"{arg_name} must be a string or a list of strings.")
31 changes: 16 additions & 15 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tabulate import tabulate

from tableone.deprecations import deprecated_parameter
from tableone.preprocessors import ensure_list
from tableone.statistics import Statistics
from tableone.validators import DataValidator, InputValidator, InputError

Expand Down Expand Up @@ -68,10 +69,17 @@ class TableOne:
variables.
columns : list, optional
List of columns in the dataset to be included in the final table.
Setting the argument to None will include all columns by default.
categorical : list, optional
List of columns that contain categorical variables.
If the argument is set to None (or omitted), we attempt to detect
categorical variables. Set to an empty list to indicate explicitly
that there are no variables of this type to be included.
continuous : list, optional
List of columns that contain continuous variables.
If the argument is set to None (or omitted), we attempt to detect
continuous variables. Set to an empty list to indicate explicitly
that there are no variables of this type to be included.
groupby : str, optional
Optional column for stratifying the final table (default: None).
nonnormal : list, optional
Expand Down Expand Up @@ -217,26 +225,21 @@ def __init__(self, data: pd.DataFrame,
deprecated_parameter(pval_test_name, "pval_test_name", "Use 'htest_name' instead")
deprecated_parameter(remarks, "remarks", "Use test names instead (e.g. diptest = True)")

if not columns:
columns = data.columns.values # type: ignore
self._columns = columns if columns else data.columns.to_list() # type: ignore
self._nonnormal = ensure_list(nonnormal, arg_name="nonnormal") # type: ignore

self.statistics = Statistics()
self.data_validator = DataValidator()
self.data_validator.validate(data, columns) # type: ignore
self.data_validator.validate(data, self._columns) # type: ignore

self.input_validator = InputValidator()
self.input_validator.validate(groupby, nonnormal, min_max, pval_adjust, order, # type: ignore
pval, columns, categorical, continuous) # type: ignore

# nonnormal should be a list
if not nonnormal:
nonnormal = []
elif nonnormal and type(nonnormal) == str:
nonnormal = [nonnormal]
self.input_validator.validate(groupby, self._nonnormal, min_max, pval_adjust, order, # type: ignore
pval, self._columns, categorical, continuous) # type: ignore

# if categorical not specified, try to identify categorical
# if empty list is provided, assume there are no categorical variables.
if not categorical and type(categorical) != list:
categorical = self._detect_categorical_columns(data[columns])
categorical = self._detect_categorical_columns(data[self._columns])
# omit categorical row if it is specified in groupby
if groupby:
categorical = [x for x in categorical if x != groupby]
Expand All @@ -259,11 +262,10 @@ def __init__(self, data: pd.DataFrame,
order = d_order_cats # type: ignore

self._alt_labels = rename
self._columns = list(columns) # type: ignore
if continuous:
self._continuous = continuous
else:
self._continuous = [c for c in columns # type: ignore
self._continuous = [c for c in self._columns # type: ignore
if c not in categorical + [groupby]]
self._categorical = categorical
self._ddof = ddof
Expand All @@ -275,7 +277,6 @@ def __init__(self, data: pd.DataFrame,
self._label_suffix = label_suffix
self._limit = limit
self._min_max = min_max
self._nonnormal = nonnormal
self._normal_test = normal_test
self._order = order
self._overall = overall
Expand Down

0 comments on commit baa2b1f

Please sign in to comment.