Skip to content

Commit

Permalink
Fix profile file output when running multiple chains in one process
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed May 17, 2024
1 parent fad7a69 commit cd6736f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
20 changes: 10 additions & 10 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,40 +61,40 @@ def __init__(
self._base_outfile = (
f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
)
# per-process console messages
# per-process outputs
self._stdout_files = [''] * self._num_procs
self._profile_files = [''] * self._num_procs # optional
if one_process_per_chain:
for i in range(chains):
self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
if args.save_profile:
self._profile_files[i] = self.file_path(
".csv", extra="-profile", id=chain_ids[i]
)
else:
self._stdout_files[0] = self.file_path("-stdout.txt")
if args.save_profile:
self._profile_files[0] = self.file_path(
".csv", extra="-profile"
)

# per-chain output files
self._csv_files: List[str] = [''] * chains
self._diagnostic_files = [''] * chains # optional
self._profile_files = [''] * chains # optional

if chains == 1:
self._csv_files[0] = self.file_path(".csv")
if args.save_latent_dynamics:
self._diagnostic_files[0] = self.file_path(
".csv", extra="-diagnostic"
)
if args.save_profile:
self._profile_files[0] = self.file_path(
".csv", extra="-profile"
)
else:
for i in range(chains):
self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
if args.save_latent_dynamics:
self._diagnostic_files[i] = self.file_path(
".csv", extra="-diagnostic", id=chain_ids[i]
)
if args.save_profile:
self._profile_files[i] = self.file_path(
".csv", extra="-profile", id=chain_ids[i]
)

def __repr__(self) -> str:
repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
Expand Down
19 changes: 11 additions & 8 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,33 +1748,36 @@ def test_save_latent_dynamics() -> None:

def test_save_profile() -> None:
stan = os.path.join(DATAFILES_PATH, 'profile_likelihood.stan')
profile_model = CmdStanModel(stan_file=stan)
profile_model = CmdStanModel(
stan_file=stan, cpp_options={"STAN_THREADS": '1'}, force_compile=True
)

profile_fit = profile_model.sample(
chains=2,
parallel_chains=2,
force_one_process_per_chain=True,
seed=12345,
iter_warmup=100,
iter_sampling=200,
save_profile=True,
)
for i in range(profile_fit.runset.chains):
profile_file = profile_fit.runset.profile_files[i]
assert len(profile_fit.runset.profile_files) == 2
for profile_file in profile_fit.runset.profile_files:
assert os.path.exists(profile_file)

profile_fit = profile_model.sample(
chains=2,
parallel_chains=2,
force_one_process_per_chain=False,
seed=12345,
iter_warmup=100,
iter_sampling=200,
save_latent_dynamics=True,
save_profile=True,
)

for i in range(profile_fit.runset.chains):
profile_file = profile_fit.runset.profile_files[i]
assert len(profile_fit.runset.profile_files) == 1
for profile_file in profile_fit.runset.profile_files:
assert os.path.exists(profile_file)
diagnostics_file = profile_fit.runset.diagnostic_files[i]
assert os.path.exists(diagnostics_file)


def test_xarray_draws() -> None:
Expand Down

0 comments on commit cd6736f

Please sign in to comment.