Skip to content

Commit

Permalink
Define order for categorical variables in the preprocessor module.
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jun 7, 2024
1 parent 0835de6 commit 50f3c78
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
25 changes: 25 additions & 0 deletions tableone/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,28 @@ def detect_categorical(data, groupby) -> list:
likely_cat = [x for x in likely_cat if x != groupby]

return likely_cat


def order_categorical(data, order):
    """
    Define an order for categorical variables.

    Columns of ``data`` whose dtype is an *ordered* pandas categorical
    contribute their category order (each category rendered as a string
    with ``"{}".format``). That order is then merged with the
    user-supplied ``order``; the user-supplied order takes precedence.

    Parameters
    ----------
    data : pandas.DataFrame
        The input data.
    order : dict or None
        Optional mapping of column name to a list of category labels in
        the desired order.

    Returns
    -------
    dict or None
        * ``order`` itself, unchanged, when ``data`` has no ordered
          categorical columns.
        * The order derived from ``data`` when ``order`` is empty or None.
        * Otherwise a new dict holding every key of both mappings. For
          keys present in ``order``, the custom labels come first,
          followed by any remaining labels for that key that were not
          listed. The ``order`` argument is not modified.
    """
    # Names of the columns holding ordered categorical data.
    order_cats = [x for x in data.select_dtypes("category")
                  if data[x].dtype.ordered]  # type: ignore

    # Check the list itself rather than any(order_cats): any() inspects the
    # truthiness of the column *names*, so a column labelled 0 or "" would
    # be treated as "no ordered categoricals" and the derived order would
    # never be built, failing further down.
    if not order_cats:
        return order

    # Category labels are stringified, in their categorical order.
    d_order_cats = {
        col: ["{}".format(cat) for cat in data[col].cat.categories]
        for col in order_cats
    }

    if not order:
        return d_order_cats

    # Combine the orders. Custom order takes precedence: listed labels
    # first, then whatever labels remain for that key.
    new = {**order, **d_order_cats}  # type: ignore
    for k in order:
        new[k] = order[k] + [x for x in new[k] if x not in order[k]]
    return new
20 changes: 2 additions & 18 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tabulate import tabulate

from tableone.deprecations import deprecated_parameter
from tableone.preprocessors import ensure_list, detect_categorical
from tableone.preprocessors import ensure_list, detect_categorical, order_categorical
from tableone.statistics import Statistics
from tableone.validators import DataValidator, InputValidator, InputError

Expand Down Expand Up @@ -243,22 +243,7 @@ def __init__(self, data: pd.DataFrame,
else:
self._categorical = categorical

# if input df has ordered categorical variables, get the order.
order_cats = [x for x in data.select_dtypes("category")
if data[x].dtype.ordered] # type: ignore
if any(order_cats):
d_order_cats = {v: data[v].cat.categories for v in order_cats}
d_order_cats = {k: ["{}".format(v) for v in d_order_cats[k]]
for k in d_order_cats}

# combine the orders. custom order takes precedence.
if order_cats and order:
new = {**order, **d_order_cats} # type: ignore
for k in order:
new[k] = order[k] + [x for x in new[k] if x not in order[k]]
order = new
elif order_cats:
order = d_order_cats # type: ignore
self._order = order_categorical(data, order)

self._alt_labels = rename
if continuous:
Expand All @@ -276,7 +261,6 @@ def __init__(self, data: pd.DataFrame,
self._limit = limit
self._min_max = min_max
self._normal_test = normal_test
self._order = order
self._overall = overall
self._pval = pval
self._pval_adjust = pval_adjust
Expand Down

0 comments on commit 50f3c78

Please sign in to comment.