Skip to content

Commit

Permalink
Merge pull request #718 from stan-dev/fix/717-draws_pd-diagnostic-fil…
Browse files Browse the repository at this point in the history
…tering

Allow the 'vars' argument to draws_pd to filter new columns
  • Loading branch information
WardBrian authored Jan 22, 2024
2 parents e1bf7d9 + 29d5b85 commit 46f8608
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 26 deletions.
12 changes: 6 additions & 6 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def draws_pd(

self._assemble_generated_quantities()

all_columns = ['chain__', 'iter__', 'draw__'] + list(self.column_names)

gq_cols: List[str] = []
mcmc_vars: List[str] = []
if vars is not None:
Expand All @@ -341,10 +343,12 @@ def draws_pd(
info.start_idx : info.end_idx
]
)
elif var in ['chain__', 'iter__', 'draw__']:
gq_cols.append(var)
else:
raise ValueError('Unknown variable: {}'.format(var))
else:
gq_cols = list(self.column_names)
gq_cols = all_columns
vars_list = gq_cols

previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)
Expand All @@ -369,13 +373,9 @@ def draws_pd(
)
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)

vars_list = ['chain__', 'iter__', 'draw__'] + vars_list
if gq_cols:
gq_cols = ['chain__', 'iter__', 'draw__'] + gq_cols

draws_pd = pd.DataFrame(
data=flatten_chains(draws),
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
columns=all_columns,
)

if inc_sample and mcmc_vars:
Expand Down
6 changes: 3 additions & 3 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,12 @@ def draws_pd(
cols.extend(
self.column_names[info.start_idx : info.end_idx]
)
elif var in ['chain__', 'iter__', 'draw__']:
cols.append(var)
else:
raise ValueError(f'Unknown variable: {var}')
else:
cols = list(self.column_names)
cols = ['chain__', 'iter__', 'draw__'] + list(self.column_names)

draws = self.draws(inc_warmup=inc_warmup)
# add long-form columns for chain, iteration, draw
Expand All @@ -640,8 +642,6 @@ def draws_pd(
)
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)

cols = ['chain__', 'iter__', 'draw__'] + cols

return pd.DataFrame(
data=flatten_chains(draws),
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
Expand Down
12 changes: 7 additions & 5 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
- 3 # chain, iter, draw duplicates
)

assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (
["chain__", "iter__", "draw__"] + column_names
)
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (column_names)

assert list(
bern_gqs.draws_pd(vars=["chain__", "iter__", "draw__", 'y_rep']).columns
) == (["chain__", "iter__", "draw__"] + column_names)


def test_pd_xr_agreement():
Expand Down Expand Up @@ -315,9 +317,9 @@ def test_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 13)
assert bern_gqs.draws_pd(vars=['y_rep'], inc_warmup=False).shape == (
400,
13,
10,
)
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 13)
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 10)

theta = bern_gqs.stan_variable(var='theta')
assert theta.shape == (400,)
Expand Down
19 changes: 7 additions & 12 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,24 +778,19 @@ def test_validate_good_run() -> None:
fit.runset.chains * fit.num_draws_sampling,
len(fit.column_names) + 3,
)
assert fit.draws_pd(vars=['theta']).shape == (400, 4)
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 5)
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 5)
assert fit.draws_pd(vars='theta').shape == (400, 4)
assert fit.draws_pd(vars=['theta']).shape == (400, 1)
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 2)
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 2)
assert fit.draws_pd(vars='theta').shape == (400, 1)

assert list(fit.draws_pd(vars=['theta', 'lp__']).columns) == [
'chain__',
'iter__',
'draw__',
'theta',
'lp__',
]
assert list(fit.draws_pd(vars=['lp__', 'theta']).columns) == [
'chain__',
'iter__',
'draw__',
assert list(fit.draws_pd(vars=['lp__', 'theta', 'iter__']).columns) == [
'lp__',
'theta',
'iter__',
]

summary = fit.summary()
Expand Down Expand Up @@ -854,7 +849,7 @@ def test_validate_big_run() -> None:
assert fit.step_size.shape == (2,)
assert fit.metric.shape == (2, 2095)
assert fit.draws().shape == (1000, 2, 2102)
assert fit.draws_pd(vars=['phi']).shape == (2000, 2098)
assert fit.draws_pd(vars=['phi']).shape == (2000, 2095)
with raises_nested(ValueError, r'Unknown variable: gamma'):
fit.draws_pd(vars=['gamma'])

Expand Down

0 comments on commit 46f8608

Please sign in to comment.