diff --git a/tableone/tableone.py b/tableone/tableone.py
index 8008f52..4874644 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -270,7 +270,16 @@ def __init__(self, data: pd.DataFrame,
         self._groupbylvls = get_groups(data, self._groupby, self._order,
                                        self._reserved_columns)
 
+        # Intermediate tables (each defaults to None until it is computed)
         self.tables = Tables()
+        self._htest_table = None
+        self.cat_describe_all = None
+        self.cont_describe_all = None
+        self.cat_describe = None
+        self.cont_describe = None
+        self.smd_table = None
+        self.cat_table = None
+        self.cont_table = None
 
         # forgive me jraffa
         if self._pval:
@@ -331,7 +340,17 @@ def __init__(self, data: pd.DataFrame,
             self.cat_table = self._create_cat_table(data, self._overall)
 
         if self._continuous:
-            self.cont_table = self._create_cont_table(data, self._overall)
+            self.cont_table = self.tables.create_cont_table(data,
+                                                            self._overall,
+                                                            self.cont_describe,
+                                                            self.cont_describe_all,
+                                                            self._continuous,
+                                                            self._pval,
+                                                            self._pval_adjust,
+                                                            self._htest_table,
+                                                            self._smd,
+                                                            self.smd_table,
+                                                            self._groupby)
 
         # combine continuous variables and categorical variables into table 1
         self.tableone = self._create_tableone(data)
@@ -503,51 +522,6 @@ def _t1_summary(self, x: pd.Series) -> str:
         f = '{{:.{}f}} ({{:.{}f}})'.format(n, n)
         return f.format(np.nanmean(x.values), self.statistics._std(x, self._ddof))  # type: ignore
 
-    def _create_cont_table(self, data, overall) -> pd.DataFrame:
-        """
-        Create tableone for continuous data.
-
-        Returns
-        ----------
-        table : pandas DataFrame
-            A table summarising the continuous variables.
-        """
-        # remove the t1_summary level
-        table = self.cont_describe[['t1_summary']].copy()
-        table.columns = table.columns.droplevel(level=0)
-
-        # add a column of null counts as 1-count() from previous function
-        nulltable = data[self._continuous].isnull().sum().to_frame(
-            name='Missing')
-        try:
-            table = table.join(nulltable)
-        # if columns form a CategoricalIndex, need to convert to string first
-        except TypeError:
-            table.columns = table.columns.astype(str)
-            table = table.join(nulltable)
-
-        # add an empty value column, for joining with cat table
-        table['value'] = ''
-        table = table.set_index([table.index, 'value'])  # type: ignore
-
-        # add pval column
-        if self._pval and self._pval_adjust:
-            table = table.join(self._htest_table[['P-Value (adjusted)',
-                                                  'Test']])
-        elif self._pval:
-            table = table.join(self._htest_table[['P-Value', 'Test']])
-
-        # add standardized mean difference (SMD) column/s
-        if self._smd:
-            table = table.join(self.smd_table)
-
-        # join the overall column if needed
-        if self._groupby and overall:
-            table = table.join(pd.concat([self.cont_describe_all['t1_summary'].
-                                          Overall], axis=1, keys=["Overall"]))
-
-        return table
-
     def _create_cat_table(self, data, overall):
         """
         Create table one for categorical data.
diff --git a/tableone/tables.py b/tableone/tables.py
index 4c084d5..608cc4d 100644
--- a/tableone/tables.py
+++ b/tableone/tables.py
@@ -384,3 +384,58 @@ def _non_continuous_warning(self, c):
         msg = ("'{}' has all non-numeric values. Consider including "
                "it in the list of categorical variables.").format(c)
         warnings.warn(msg, RuntimeWarning, stacklevel=2)
+
+    def create_cont_table(self,
+                          data,  # DataFrame; nulls are counted over data[continuous]
+                          overall,  # with groupby: if truthy, join an "Overall" column
+                          cont_describe,  # its 't1_summary' columns seed the table
+                          cont_describe_all,  # used only when groupby and overall are truthy
+                          continuous,  # selector for the continuous columns of data
+                          pval,  # if truthy, join p-value and 'Test' from htest_table
+                          pval_adjust,  # with pval: use 'P-Value (adjusted)'
+                          htest_table,  # used only when pval is truthy
+                          smd,  # if truthy, join smd_table
+                          smd_table,  # used only when smd is truthy
+                          groupby  # see overall
+                          ) -> pd.DataFrame:
+        """
+        Create tableone for continuous data.
+
+        Returns
+        ----------
+        table : pandas DataFrame
+            A table summarising the continuous variables.
+        """
+        # keep only the 't1_summary' columns, then drop that outer column level
+        table = cont_describe[['t1_summary']].copy()
+        table.columns = table.columns.droplevel(level=0)
+
+        # add a 'Missing' column: null count per column of data[continuous]
+        nulltable = data[continuous].isnull().sum().to_frame(name='Missing')
+        try:
+            table = table.join(nulltable)
+        # if columns form a CategoricalIndex, need to convert to string first
+        except TypeError:
+            table.columns = table.columns.astype(str)
+            table = table.join(nulltable)
+
+        # add an empty value column, for joining with cat table
+        table['value'] = ''
+        table = table.set_index([table.index, 'value'])  # type: ignore
+
+        # add p-value and 'Test' columns (adjusted p-values if pval_adjust)
+        if pval and pval_adjust:
+            table = table.join(htest_table[['P-Value (adjusted)', 'Test']])
+        elif pval:
+            table = table.join(htest_table[['P-Value', 'Test']])
+
+        # add standardized mean difference (SMD) column/s
+        if smd:
+            table = table.join(smd_table)
+
+        # join the overall column if needed
+        if groupby and overall:
+            table = table.join(pd.concat([cont_describe_all['t1_summary'].
+                                          Overall], axis=1, keys=["Overall"]))
+
+        return table