Skip to content

Commit

Permalink
Merge pull request #731 from stan-dev/fix/730-warmup-no-adapt
Browse files Browse the repository at this point in the history
Allow warmup iterations when not adapting
  • Loading branch information
WardBrian authored Jan 25, 2024
2 parents 46f8608 + af8a747 commit e7e6121
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
5 changes: 2 additions & 3 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,9 @@ def validate(self, chains: Optional[int]) -> None:
'Value for iter_warmup must be a non-negative integer,'
' found {}.'.format(self.iter_warmup)
)
if self.iter_warmup > 0 and not self.adapt_engaged:
if self.iter_warmup == 0 and self.adapt_engaged:
raise ValueError(
'Argument "adapt_engaged" is False, '
'cannot specify warmup iterations.'
'Must specify iter_warmup > 0 when adapt_engaged=True.'
)
if self.iter_sampling is not None:
if self.iter_sampling < 0 or not isinstance(
Expand Down
2 changes: 1 addition & 1 deletion test/test_cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_bad() -> None:
with pytest.raises(ValueError):
args.validate(chains=2)

args = SamplerArgs(iter_warmup=10, adapt_engaged=False)
args = SamplerArgs(iter_warmup=0, adapt_engaged=True)
with pytest.raises(ValueError):
args.validate(chains=2)

Expand Down
20 changes: 20 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,26 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
)


def test_warmup_no_adapt() -> None:
# we may want to have a "burn-in" period, even without adaptation
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

bern_model = CmdStanModel(stan_file=stan)
bern_fit = bern_model.sample(
data=jdata,
chains=2,
seed=12345,
iter_warmup=200,
iter_sampling=100,
adapt_engaged=False,
)

assert bern_fit.column_names == tuple(BERNOULLI_COLS)
assert bern_fit.num_draws_sampling == 100
assert bern_fit.draws().shape == (100, 2, len(BERNOULLI_COLS))


def test_sampler_diags() -> None:
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
Expand Down

0 comments on commit e7e6121

Please sign in to comment.