diff --git a/cmdstanpy/stanfit/gq.py b/cmdstanpy/stanfit/gq.py index f383eaec..6c77ec95 100644 --- a/cmdstanpy/stanfit/gq.py +++ b/cmdstanpy/stanfit/gq.py @@ -323,6 +323,8 @@ def draws_pd( self._assemble_generated_quantities() + all_columns = ['chain__', 'iter__', 'draw__'] + list(self.column_names) + gq_cols: List[str] = [] mcmc_vars: List[str] = [] if vars is not None: @@ -341,10 +343,12 @@ def draws_pd( info.start_idx : info.end_idx ] ) + elif var in ['chain__', 'iter__', 'draw__']: + gq_cols.append(var) else: raise ValueError('Unknown variable: {}'.format(var)) else: - gq_cols = list(self.column_names) + gq_cols = all_columns vars_list = gq_cols previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup) @@ -369,13 +373,9 @@ def draws_pd( ) draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2) - vars_list = ['chain__', 'iter__', 'draw__'] + vars_list - if gq_cols: - gq_cols = ['chain__', 'iter__', 'draw__'] + gq_cols - draws_pd = pd.DataFrame( data=flatten_chains(draws), - columns=['chain__', 'iter__', 'draw__'] + list(self.column_names), + columns=all_columns, ) if inc_sample and mcmc_vars: diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index 4d908033..0bc1e599 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -615,10 +615,12 @@ def draws_pd( cols.extend( self.column_names[info.start_idx : info.end_idx] ) + elif var in ['chain__', 'iter__', 'draw__']: + cols.append(var) else: raise ValueError(f'Unknown variable: {var}') else: - cols = list(self.column_names) + cols = ['chain__', 'iter__', 'draw__'] + list(self.column_names) draws = self.draws(inc_warmup=inc_warmup) # add long-form columns for chain, iteration, draw @@ -640,8 +642,6 @@ def draws_pd( ) draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2) - cols = ['chain__', 'iter__', 'draw__'] + cols - return pd.DataFrame( data=flatten_chains(draws), columns=['chain__', 'iter__', 'draw__'] + list(self.column_names), diff --git a/test/test_generate_quantities.py b/test/test_generate_quantities.py index b88ba327..0fc6ded3 100644 --- a/test/test_generate_quantities.py +++ b/test/test_generate_quantities.py @@ -85,9 +85,11 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None: - 3 # chain, iter, draw duplicates ) - assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == ( - ["chain__", "iter__", "draw__"] + column_names - ) + assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (column_names) + + assert list( + bern_gqs.draws_pd(vars=["chain__", "iter__", "draw__", 'y_rep']).columns + ) == (["chain__", "iter__", "draw__"] + column_names) def test_pd_xr_agreement(): @@ -315,9 +317,9 @@ def test_save_warmup(caplog: pytest.LogCaptureFixture) -> None: assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 13) assert bern_gqs.draws_pd(vars=['y_rep'], inc_warmup=False).shape == ( 400, - 13, + 10, ) - assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 13) + assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 10) theta = bern_gqs.stan_variable(var='theta') assert theta.shape == (400,) diff --git a/test/test_sample.py b/test/test_sample.py index c783f92c..02404d58 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -778,24 +778,19 @@ def test_validate_good_run() -> None: fit.runset.chains * fit.num_draws_sampling, len(fit.column_names) + 3, ) - assert fit.draws_pd(vars=['theta']).shape == (400, 4) - assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 5) - assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 5) - assert fit.draws_pd(vars='theta').shape == (400, 4) + assert fit.draws_pd(vars=['theta']).shape == (400, 1) + assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 2) + assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 2) + assert fit.draws_pd(vars='theta').shape == (400, 1) assert list(fit.draws_pd(vars=['theta', 'lp__']).columns) == [ - 'chain__', - 'iter__', - 'draw__', 'theta', 'lp__', ] - assert list(fit.draws_pd(vars=['lp__', 'theta']).columns) == [ - 'chain__', - 'iter__', - 'draw__', + assert list(fit.draws_pd(vars=['lp__', 'theta', 'iter__']).columns) == [ 'lp__', 'theta', + 'iter__', ] summary = fit.summary() @@ -854,7 +849,7 @@ def test_validate_big_run() -> None: assert fit.step_size.shape == (2,) assert fit.metric.shape == (2, 2095) assert fit.draws().shape == (1000, 2, 2102) - assert fit.draws_pd(vars=['phi']).shape == (2000, 2098) + assert fit.draws_pd(vars=['phi']).shape == (2000, 2095) with raises_nested(ValueError, r'Unknown variable: gamma'): fit.draws_pd(vars=['gamma'])