Skip to content

Commit

Permalink
CmdStan 2.36 no longer requires fixed_param hacks
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Nov 25, 2024
1 parent c5bcfb3 commit 7667d52
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 43 deletions.
6 changes: 3 additions & 3 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def __init__(
self._compiler_options.add_include_path(path)

# try to detect models w/out parameters, needed for sampler
if not cmdstan_version_before(
2, 27
): # unknown end of version range
if not cmdstan_version_before(2, 27) and cmdstan_version_before(
2, 36
):
try:
model_info = self.src_info()
if 'parameters' in model_info:
Expand Down
23 changes: 13 additions & 10 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _assemble_draws(self) -> None:
self._step_size[chain] = float(step_size.strip())
if self._metadata.cmdstan_config['metric'] != 'unit_e':
line = fd.readline().strip() # metric type
line = fd.readline().lstrip(' #\t')
line = fd.readline().lstrip(' #\t').rstrip()
num_unconstrained_params = len(line.split(','))
if chain == 0: # can't allocate w/o num params
if self.metric_type == 'diag_e':
Expand All @@ -429,18 +429,21 @@ def _assemble_draws(self) -> None:
),
dtype=float,
)
if self.metric_type == 'diag_e':
xs = line.split(',')
self._metric[chain, :] = [float(x) for x in xs]
else:
xs = line.split(',')
self._metric[chain, 0, :] = [float(x) for x in xs]
for i in range(1, num_unconstrained_params):
line = fd.readline().lstrip(' #\t').strip()
if line:
if self.metric_type == 'diag_e':
xs = line.split(',')
self._metric[chain, i, :] = [
self._metric[chain, :] = [float(x) for x in xs]
else:
xs = line.strip().split(',')
self._metric[chain, 0, :] = [
float(x) for x in xs
]
for i in range(1, num_unconstrained_params):
line = fd.readline().lstrip(' #\t').rstrip()
xs = line.split(',')
self._metric[chain, i, :] = [
float(x) for x in xs
]
else: # unit_e changed in 2.34 to have an extra line
pos = fd.tell()
line = fd.readline().strip()
Expand Down
45 changes: 15 additions & 30 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,23 +642,23 @@ def test_fixed_param_good() -> None:
assert datagen_fit.step_size is None


def test_fixed_param_unspecified() -> None:
def test_sample_no_params() -> None:
stan = os.path.join(DATAFILES_PATH, 'datagen_poisson_glm.stan')
datagen_model = CmdStanModel(stan_file=stan)
datagen_fit = datagen_model.sample(iter_sampling=100, show_progress=False)
assert datagen_fit.step_size is None
assert np.isnan(datagen_fit.step_size).all()
summary = datagen_fit.summary()
assert 'lp__' not in list(summary.index)
assert 'lp__' in list(summary.index)

exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
shutil.copyfile(datagen_model.exe_file, exe_only)
os.chmod(exe_only, 0o755)
datagen2_model = CmdStanModel(exe_file=exe_only)
datagen2_fit = datagen2_model.sample(iter_sampling=200, show_console=True)
assert datagen2_fit.chains == 4
assert datagen2_fit.step_size is None
assert np.isnan(datagen2_fit.step_size).all()
summary = datagen2_fit.summary()
assert 'lp__' not in list(summary.index)
assert 'lp__' in list(summary.index)


def test_index_bounds_error() -> None:
Expand Down Expand Up @@ -823,7 +823,7 @@ def test_validate_good_run() -> None:
assert 'Treedepth satisfactory for all transitions.' in diagnostics
assert 'No divergent transitions found.' in diagnostics
assert 'E-BFMI satisfactory' in diagnostics
assert 'Effective sample size satisfactory.' in diagnostics
assert 'effective sample size satisfactory.' in diagnostics.lower()


def test_validate_big_run() -> None:
Expand Down Expand Up @@ -1621,33 +1621,18 @@ def test_validate_sample_sig_figs(stanfile='bernoulli.stan'):


def test_validate_summary_sig_figs() -> None:
# construct CmdStanMCMC from logistic model output, config
exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
sampler_args = SamplerArgs(iter_sampling=100)
cmdstan_args = CmdStanArgs(
model_name='logistic',
model_exe=exe,
chain_ids=[1, 2, 3, 4],
seed=12345,
data=rdata,
output_dir=DATAFILES_PATH,
sig_figs=17,
method_args=sampler_args,
# construct CmdStanMCMC from logistic model output
fit = from_csv(
[
os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
]
)
runset = RunSet(args=cmdstan_args, chains=4)
runset._csv_files = [
os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
]
retcodes = runset._retcodes
for i in range(len(retcodes)):
runset._set_retcode(i, 0)
fit = CmdStanMCMC(runset)

sum_default = fit.summary()

beta1_default = format(sum_default.iloc[1, 0], '.18g')
assert beta1_default.startswith('1.3')

Expand Down

0 comments on commit 7667d52

Please sign in to comment.