Skip to content

Commit

Permalink
Refactor for readability.
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jun 16, 2024
1 parent dd7c93d commit 6d0bac9
Showing 1 changed file with 130 additions and 34 deletions.
164 changes: 130 additions & 34 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,33 +582,30 @@ def _t1_summary(self, x: pd.Series) -> str:
f = '{{:.{}f}} ({{:.{}f}})'.format(n, n)
return f.format(np.nanmean(x.values), self.statistics._std(x, self._ddof)) # type: ignore

def _create_tableone(self, data):
def _combine_tables(self):
"""
Create table 1 by combining the continuous and categorical tables.
Returns
----------
table : pandas DataFrame
The complete table one.
Combine continuous and categorical tables.
"""
if self._continuous and self._categorical:
# support pandas<=0.22
try:
table = pd.concat([self.cont_table, self.cat_table],
sort=False)
table = pd.concat([self.cont_table, self.cat_table], sort=False)
except TypeError:
table = pd.concat([self.cont_table, self.cat_table])
elif self._continuous:
table = self.cont_table
elif self._categorical:
table = self.cat_table

# ensure column headers are strings before reindexing
table = table.reset_index().set_index(['variable', 'value']) # type: ignore
table.columns = table.columns.values.astype(str)
return table

def _sort_and_reindex(self, table):
"""
Sorts and reindexes the table to meet requirements.
"""
# sort the table rows
sort_columns = ['Missing', 'P-Value', 'P-Value (adjusted)', 'Test']

if self._smd and self.smd_table is not None:
sort_columns = sort_columns + list(self.smd_table.columns)

Expand All @@ -634,6 +631,23 @@ def _create_tableone(self, data):
key=lambda x: self._columns.index(x[0]))
table = table.reindex(new_index)

return table

def _format_values(self, table):
"""
Formats the numerical values in the table, specifically focusing on the p value
and SMD (Standardized Mean Differences) columns. It applies rounding and
converts numbers to strings for better presentation.
"""
table = self._format_pvalues(table)
table = self._format_smd_columns(table)
return table

def _format_pvalues(self, table):
"""
Formats the p value columns, applying rounding rules and adding
significance markers based on defined thresholds.
"""
# round pval column and convert to string
if self._pval and self._pval_adjust:
if self._pval_threshold:
Expand All @@ -658,12 +672,26 @@ def _create_tableone(self, data):
if self._pval_threshold:
table.loc[asterisk_mask, 'P-Value'] = table['P-Value'][asterisk_mask].astype(str)+"*" # type: ignore

return table

def _format_smd_columns(self, table):
"""
Formats the SMD (Standardized Mean Differences) columns. Rounds the SMD values
and ensures they are presented as strings.
"""
# round smd columns and convert to string
if self._smd and self.smd_table is not None:
for c in list(self.smd_table.columns):
table[c] = table[c].apply('{:.3f}'.format).astype(str)
table.loc[table[c] == '0.000', c] = '<0.001'

return table

def _apply_order(self, table):
"""
Applies a predefined order to rows based on specified requirements.
May include reordering based on categorical group levels or other criteria.
"""
# if an order is specified, apply it
if self._order:
for k in self._order:
Expand Down Expand Up @@ -695,6 +723,13 @@ def _create_tableone(self, data):
orig_idx[table.index.get_loc(k)] = new_idx_array
table = table.reindex(orig_idx)

return table

def _apply_limits(self, table, data):
"""
Applies limits to the number of categories shown for each categorical variable
in the DataFrame, based on specified requirements.
"""
# set the limit on the number of categorical variables
if self._limit:
levelcounts = data[self._categorical].nunique()
Expand Down Expand Up @@ -730,6 +765,13 @@ def _create_tableone(self, data):
# drop the rows > the limit
table = table.drop(new_idx_array[limit:]) # type: ignore

return table

def _insert_n_row(self, table, data):
"""
Inserts a row that shows 'n', the total number or count of items
within each group or overall.
"""
# insert n row
n_row = pd.DataFrame(columns=['variable', 'value', 'Missing'])
n_row = n_row.set_index(['variable', 'value'])
Expand All @@ -750,10 +792,17 @@ def _create_tableone(self, data):
ct = data[self._groupby][data[self._groupby] == g].count()
table.loc['n', '{}'.format(g)] = ct

return table

def _mask_duplicate_values(self, table, optional_columns):
"""
Masks duplicate values, ensuring that repeated values (e.g. counts of
missing values) are only displayed once.
"""
# only display data in first level row
dupe_mask = table.groupby(level=[0]).cumcount().ne(0) # type: ignore
dupe_columns = ['Missing']
optional_columns = ['P-Value', 'P-Value (adjusted)', 'Test']

if self._smd and self.smd_table is not None:
optional_columns = optional_columns + list(self.smd_table.columns)
for col in optional_columns:
Expand All @@ -762,31 +811,21 @@ def _create_tableone(self, data):

table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('')

# remove Missing column if not needed
if not self._isnull:
table = table.drop('Missing', axis=1)

if self._pval and not self._pval_test_name:
table = table.drop('Test', axis=1)

# replace nans with empty strings
table = table.fillna('')

# add column index
if not self._groupbylvls == ['Overall']:
# rename groupby variable if requested
c = self._groupby
if self._alt_labels:
if self._groupby in self._alt_labels:
c = self._alt_labels[self._groupby]

c = 'Grouped by {}'.format(c)
table.columns = pd.MultiIndex.from_product([[c], table.columns])
return table

def _apply_alt_labels(self, table):
"""
Applies alternative labels to the variables if required.
"""
# display alternative labels if assigned
table = table.rename(index=self._create_row_labels(), level=0)

# ensure the order of columns is consistent
return table

def _reorder_columns(self, table, optional_columns):
"""
Reorder columns for consistent, predictable formatting.
"""
if self._groupby and self._order and (self._groupby in self._order):
header = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore
cols = self._order[self._groupby] + ['{}'.format(v)
Expand Down Expand Up @@ -814,6 +853,63 @@ def _create_tableone(self, data):
else:
table = table.reindex(cols, axis=1)

return table

def _add_groupby_columns(self, table):
"""
Adds multi-level column headers to denote grouping by a specific
variable, for clarity.
"""
# add column index
if not self._groupbylvls == ['Overall']:
# rename groupby variable if requested
c = self._groupby
if self._alt_labels:
if self._groupby in self._alt_labels:
c = self._alt_labels[self._groupby]

c = 'Grouped by {}'.format(c)
table.columns = pd.MultiIndex.from_product([[c], table.columns])

return table

def _create_tableone(self, data):
"""
Create table 1 by combining the continuous and categorical tables.
Returns
----------
table : pandas DataFrame
The complete table one.
"""
table = self._combine_tables()
optional_columns = ['P-Value', 'P-Value (adjusted)', 'Test']

# ensure column headers are strings before reindexing
table = table.reset_index().set_index(['variable', 'value']) # type: ignore
table.columns = table.columns.values.astype(str)

table = self._sort_and_reindex(table)
table = self._format_values(table)
table = self._apply_order(table)
table = self._apply_limits(table, data)
table = self._insert_n_row(table, data)
table = self._mask_duplicate_values(table, optional_columns)

# remove unwanted columns
if not self._isnull:
table = table.drop('Missing', axis=1)

if self._pval and not self._pval_test_name:
table = table.drop('Test', axis=1)

# replace nans with empty strings
table = table.fillna('')

table = self._add_groupby_columns(table)
table = self._apply_alt_labels(table)
table = self._reorder_columns(table, optional_columns)

try:
if 'Missing' in self._alt_labels or 'Overall' in self._alt_labels: # type: ignore
table = table.rename(columns=self._alt_labels)
Expand Down

0 comments on commit 6d0bac9

Please sign in to comment.