From 9208822cfc5ba3c5b6de23de7cce07b56176710c Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Sat, 15 Jun 2024 23:38:35 -0400 Subject: [PATCH] Refactor for readability. --- tableone/tableone.py | 165 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 131 insertions(+), 34 deletions(-) diff --git a/tableone/tableone.py b/tableone/tableone.py index 883e9a2..c196bab 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -582,20 +582,15 @@ def _t1_summary(self, x: pd.Series) -> str: f = '{{:.{}f}} ({{:.{}f}})'.format(n, n) return f.format(np.nanmean(x.values), self.statistics._std(x, self._ddof)) # type: ignore - def _create_tableone(self, data): + def _combine_tables(self): """ - Create table 1 by combining the continuous and categorical tables. - - Returns - ---------- - table : pandas DataFrame - The complete table one. + Combine continuous and categorical tables. """ + table = pd.DataFrame() if self._continuous and self._categorical: # support pandas<=0.22 try: - table = pd.concat([self.cont_table, self.cat_table], - sort=False) + table = pd.concat([self.cont_table, self.cat_table], sort=False) except TypeError: table = pd.concat([self.cont_table, self.cat_table]) elif self._continuous: @@ -603,12 +598,15 @@ def _create_tableone(self, data): elif self._categorical: table = self.cat_table - # ensure column headers are strings before reindexing - table = table.reset_index().set_index(['variable', 'value']) # type: ignore - table.columns = table.columns.values.astype(str) + return table + def _sort_and_reindex(self, table): + """ + Sorts and reindexes the table to meet requirements. + """ # sort the table rows sort_columns = ['Missing', 'P-Value', 'P-Value (adjusted)', 'Test'] + if self._smd and self.smd_table is not None: sort_columns = sort_columns + list(self.smd_table.columns) @@ -634,6 +632,23 @@ def _create_tableone(self, data): key=lambda x: self._columns.index(x[0])) table = table.reindex(new_index) + return table + + def _format_values(self, table): + """ + Formats the numerical values in the table, specifically focusing on the p value + and SMD (Standardized Mean Differences) columns. It applies rounding and + converts numbers to strings for better presentation. + """ + table = self._format_pvalues(table) + table = self._format_smd_columns(table) + return table + + def _format_pvalues(self, table): + """ + Formats the p value columns, applying rounding rules and adding + significance markers based on defined thresholds. + """ # round pval column and convert to string if self._pval and self._pval_adjust: if self._pval_threshold: @@ -658,12 +673,26 @@ def _create_tableone(self, data): if self._pval_threshold: table.loc[asterisk_mask, 'P-Value'] = table['P-Value'][asterisk_mask].astype(str)+"*" # type: ignore + return table + + def _format_smd_columns(self, table): + """ + Formats the SMD (Standardized Mean Differences) columns. Rounds the SMD values + and ensures they are presented as strings. + """ # round smd columns and convert to string if self._smd and self.smd_table is not None: for c in list(self.smd_table.columns): table[c] = table[c].apply('{:.3f}'.format).astype(str) table.loc[table[c] == '0.000', c] = '<0.001' + return table + + def _apply_order(self, table): + """ + Applies a predefined order to rows based on specified requirements. + May include reordering based on categorical group levels or other criteria. + """ # if an order is specified, apply it if self._order: for k in self._order: @@ -695,6 +724,13 @@ def _create_tableone(self, data): orig_idx[table.index.get_loc(k)] = new_idx_array table = table.reindex(orig_idx) + return table + + def _apply_limits(self, table, data): + """ + Applies limits to the number of categories shown for each categorical variable + in the DataFrame, based on specified requirements. + """ # set the limit on the number of categorical variables if self._limit: levelcounts = data[self._categorical].nunique() @@ -730,6 +766,13 @@ def _create_tableone(self, data): # drop the rows > the limit table = table.drop(new_idx_array[limit:]) # type: ignore + return table + + def _insert_n_row(self, table, data): + """ + Inserts a row that shows 'n', the total number or count of items + within each group or overall. + """ # insert n row n_row = pd.DataFrame(columns=['variable', 'value', 'Missing']) n_row = n_row.set_index(['variable', 'value']) @@ -750,10 +793,17 @@ def _create_tableone(self, data): ct = data[self._groupby][data[self._groupby] == g].count() table.loc['n', '{}'.format(g)] = ct + return table + + def _mask_duplicate_values(self, table, optional_columns): + """ + Masks duplicate values, ensuring that repeated values (e.g. counts of + missing values) are only displayed once. + """ # only display data in first level row dupe_mask = table.groupby(level=[0]).cumcount().ne(0) # type: ignore dupe_columns = ['Missing'] - optional_columns = ['P-Value', 'P-Value (adjusted)', 'Test'] + if self._smd and self.smd_table is not None: optional_columns = optional_columns + list(self.smd_table.columns) for col in optional_columns: @@ -762,31 +812,21 @@ def _create_tableone(self, data): table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('') - # remove Missing column if not needed - if not self._isnull: - table = table.drop('Missing', axis=1) - - if self._pval and not self._pval_test_name: - table = table.drop('Test', axis=1) - - # replace nans with empty strings - table = table.fillna('') - - # add column index - if not self._groupbylvls == ['Overall']: - # rename groupby variable if requested - c = self._groupby - if self._alt_labels: - if self._groupby in self._alt_labels: - c = self._alt_labels[self._groupby] - - c = 'Grouped by {}'.format(c) - table.columns = pd.MultiIndex.from_product([[c], table.columns]) + return table + def _apply_alt_labels(self, table): + """ + Applies alternative labels to the variables if required. + """ # display alternative labels if assigned table = table.rename(index=self._create_row_labels(), level=0) - # ensure the order of columns is consistent + return table + + def _reorder_columns(self, table, optional_columns): + """ + Reorder columns for consistent, predictable formatting. + """ if self._groupby and self._order and (self._groupby in self._order): header = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore cols = self._order[self._groupby] + ['{}'.format(v) @@ -814,6 +854,63 @@ def _create_tableone(self, data): else: table = table.reindex(cols, axis=1) + return table + + def _add_groupby_columns(self, table): + """ + Adds multi-level column headers to denote grouping by a specific + variable, for clarity. + """ + # add column index + if not self._groupbylvls == ['Overall']: + # rename groupby variable if requested + c = self._groupby + if self._alt_labels: + if self._groupby in self._alt_labels: + c = self._alt_labels[self._groupby] + + c = 'Grouped by {}'.format(c) + table.columns = pd.MultiIndex.from_product([[c], table.columns]) + + return table + + def _create_tableone(self, data): + """ + Create table 1 by combining the continuous and categorical tables. + + Returns + ---------- + table : pandas DataFrame + The complete table one. + """ + table = self._combine_tables() + optional_columns = ['P-Value', 'P-Value (adjusted)', 'Test'] + + # ensure column headers are strings before reindexing + table = table.reset_index().set_index(['variable', 'value']) # type: ignore + table.columns = table.columns.values.astype(str) + + table = self._sort_and_reindex(table) + table = self._format_values(table) + table = self._apply_order(table) + table = self._apply_limits(table, data) + table = self._insert_n_row(table, data) + table = self._mask_duplicate_values(table, optional_columns) + + # remove unwanted columns + if not self._isnull: + table = table.drop('Missing', axis=1) + + if self._pval and not self._pval_test_name: + table = table.drop('Test', axis=1) + + # replace nans with empty strings + table = table.fillna('') + + table = self._add_groupby_columns(table) + table = self._apply_alt_labels(table) + table = self._reorder_columns(table, optional_columns) + try: if 'Missing' in self._alt_labels or 'Overall' in self._alt_labels: # type: ignore table = table.rename(columns=self._alt_labels)