Skip to content

Commit

Permalink
Merge pull request #771 from stan-dev/fixes/2.36
Browse files Browse the repository at this point in the history
CmdStan 2.36 no longer requires fixed_param hacks
  • Loading branch information
WardBrian authored Dec 3, 2024
2 parents c5bcfb3 + f53ab34 commit 0ed8811
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 45 deletions.
6 changes: 3 additions & 3 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def __init__(
self._compiler_options.add_include_path(path)

# try to detect models w/out parameters, needed for sampler
if not cmdstan_version_before(
2, 27
): # unknown end of version range
if (not cmdstan_version_before(2, 27)) and cmdstan_version_before(
2, 36
):
try:
model_info = self.src_info()
if 'parameters' in model_info:
Expand Down
23 changes: 13 additions & 10 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _assemble_draws(self) -> None:
self._step_size[chain] = float(step_size.strip())
if self._metadata.cmdstan_config['metric'] != 'unit_e':
line = fd.readline().strip() # metric type
line = fd.readline().lstrip(' #\t')
line = fd.readline().lstrip(' #\t').rstrip()
num_unconstrained_params = len(line.split(','))
if chain == 0: # can't allocate w/o num params
if self.metric_type == 'diag_e':
Expand All @@ -429,18 +429,21 @@ def _assemble_draws(self) -> None:
),
dtype=float,
)
if self.metric_type == 'diag_e':
xs = line.split(',')
self._metric[chain, :] = [float(x) for x in xs]
else:
xs = line.split(',')
self._metric[chain, 0, :] = [float(x) for x in xs]
for i in range(1, num_unconstrained_params):
line = fd.readline().lstrip(' #\t').strip()
if line:
if self.metric_type == 'diag_e':
xs = line.split(',')
self._metric[chain, i, :] = [
self._metric[chain, :] = [float(x) for x in xs]
else:
xs = line.strip().split(',')
self._metric[chain, 0, :] = [
float(x) for x in xs
]
for i in range(1, num_unconstrained_params):
line = fd.readline().lstrip(' #\t').rstrip()
xs = line.split(',')
self._metric[chain, i, :] = [
float(x) for x in xs
]
else: # unit_e changed in 2.34 to have an extra line
pos = fd.tell()
line = fd.readline().strip()
Expand Down
Loading

0 comments on commit 0ed8811

Please sign in to comment.