Skip to content

Commit

Permalink
Move method for getting group levels to preprocessor module for reada…
Browse files Browse the repository at this point in the history
…bility.
  • Loading branch information
tompollard committed Jun 7, 2024
1 parent 51affb0 commit 2391bb3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
28 changes: 28 additions & 0 deletions tableone/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from tableone.validators import InputError


def ensure_list(arg, arg_name):
"""
Ensure input argument is a list.
Expand Down Expand Up @@ -71,3 +74,28 @@ def order_categorical(data, order):
order = d_order_cats # type: ignore

return order


def get_groups(data, groupby, order, reserved_columns):
"""
Get groups for table.
If groupby is not specified, there will be a single "overall" group.
"""
if groupby:
groupbylvls = sorted(data.groupby(groupby).groups.keys()) # type: ignore

# reorder the groupby levels if order is provided
if order and groupby in order:
unordered = [x for x in groupbylvls if x not in order[groupby]]
groupbylvls = order[groupby] + unordered

# check that the group levels do not include reserved words
for level in groupbylvls:
if level in reserved_columns:
raise InputError("""Group level contains '{}', a reserved
keyword.""".format(level))
else:
groupbylvls = ['Overall']

return groupbylvls
19 changes: 2 additions & 17 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tabulate import tabulate

from tableone.deprecations import deprecated_parameter
from tableone.preprocessors import ensure_list, detect_categorical, order_categorical
from tableone.preprocessors import ensure_list, detect_categorical, order_categorical, get_groups
from tableone.statistics import Statistics
from tableone.validators import DataValidator, InputValidator, InputError

Expand Down Expand Up @@ -267,22 +267,7 @@ def __init__(self, data: pd.DataFrame,
self._tukey_test = tukey_test
self._warnings = {} # display notes and warnings below the table

if self._groupby:
self._groupbylvls = sorted(data.groupby(groupby).groups.keys()) # type: ignore

# reorder the groupby levels if order is provided
if self._order and self._groupby in self._order:
unordered = [x for x in self._groupbylvls
if x not in self._order[self._groupby]]
self._groupbylvls = self._order[self._groupby] + unordered

# check that the group levels do not include reserved words
for level in self._groupbylvls:
if level in self._reserved_columns:
raise InputError("""Group level contains '{}', a reserved
keyword.""".format(level))
else:
self._groupbylvls = ['Overall']
self._groupbylvls = get_groups(data, self._groupby, self._order, self._reserved_columns)

# forgive me jraffa
if self._pval:
Expand Down

0 comments on commit 2391bb3

Please sign in to comment.