Skip to content

Commit

Permalink
Merge pull request #751 from stan-dev/fix/multichain-profile-file
Browse files Browse the repository at this point in the history
Fix profile file output when running multiple chains in one process
  • Loading branch information
WardBrian authored May 21, 2024
2 parents fad7a69 + bf39084 commit 4f39687
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
2 changes: 2 additions & 0 deletions cmdstanpy/install_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,8 @@ def retrieve_version(version: str, progress: bool = True) -> None:
first = tar.next()
if first is not None:
top_dir = first.name
else:
top_dir = ''
cmdstan_dir = f'cmdstan-{version}'
if top_dir != cmdstan_dir:
raise CmdStanInstallError(
Expand Down
4 changes: 4 additions & 0 deletions cmdstanpy/install_cxx_toolchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,10 @@ def get_toolchain_name() -> str:
return ''


# TODO(2.0): drop 3.5 support
def get_url(version: str) -> str:
"""Return URL for toolchain."""
url = ''
if platform.system() == 'Windows':
if version == '4.0':
# pylint: disable=line-too-long
Expand Down Expand Up @@ -277,6 +279,8 @@ def run_rtools_install(args: Dict[str, Any]) -> None:

if 'verbose' in args:
verbose = args['verbose']
else:
verbose = False

install_dir = args['dir']
if install_dir is None:
Expand Down
20 changes: 10 additions & 10 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,40 +61,40 @@ def __init__(
self._base_outfile = (
f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
)
# per-process console messages
# per-process outputs
self._stdout_files = [''] * self._num_procs
self._profile_files = [''] * self._num_procs # optional
if one_process_per_chain:
for i in range(chains):
self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
if args.save_profile:
self._profile_files[i] = self.file_path(
".csv", extra="-profile", id=chain_ids[i]
)
else:
self._stdout_files[0] = self.file_path("-stdout.txt")
if args.save_profile:
self._profile_files[0] = self.file_path(
".csv", extra="-profile"
)

# per-chain output files
self._csv_files: List[str] = [''] * chains
self._diagnostic_files = [''] * chains # optional
self._profile_files = [''] * chains # optional

if chains == 1:
self._csv_files[0] = self.file_path(".csv")
if args.save_latent_dynamics:
self._diagnostic_files[0] = self.file_path(
".csv", extra="-diagnostic"
)
if args.save_profile:
self._profile_files[0] = self.file_path(
".csv", extra="-profile"
)
else:
for i in range(chains):
self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
if args.save_latent_dynamics:
self._diagnostic_files[i] = self.file_path(
".csv", extra="-diagnostic", id=chain_ids[i]
)
if args.save_profile:
self._profile_files[i] = self.file_path(
".csv", extra="-profile", id=chain_ids[i]
)

def __repr__(self) -> str:
repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
Expand Down
1 change: 1 addition & 0 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
all_iters[i, :] = [float(x) for x in xs]
if i == iters - 1:
mle: np.ndarray = np.array(xs, dtype=float)
# pylint: disable=possibly-used-before-assignment
dict['mle'] = mle
if save_iters:
dict['all_iters'] = all_iters
Expand Down
19 changes: 11 additions & 8 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,33 +1748,36 @@ def test_save_latent_dynamics() -> None:

def test_save_profile() -> None:
stan = os.path.join(DATAFILES_PATH, 'profile_likelihood.stan')
profile_model = CmdStanModel(stan_file=stan)
profile_model = CmdStanModel(
stan_file=stan, cpp_options={"STAN_THREADS": '1'}, force_compile=True
)

profile_fit = profile_model.sample(
chains=2,
parallel_chains=2,
force_one_process_per_chain=True,
seed=12345,
iter_warmup=100,
iter_sampling=200,
save_profile=True,
)
for i in range(profile_fit.runset.chains):
profile_file = profile_fit.runset.profile_files[i]
assert len(profile_fit.runset.profile_files) == 2
for profile_file in profile_fit.runset.profile_files:
assert os.path.exists(profile_file)

profile_fit = profile_model.sample(
chains=2,
parallel_chains=2,
force_one_process_per_chain=False,
seed=12345,
iter_warmup=100,
iter_sampling=200,
save_latent_dynamics=True,
save_profile=True,
)

for i in range(profile_fit.runset.chains):
profile_file = profile_fit.runset.profile_files[i]
assert len(profile_fit.runset.profile_files) == 1
for profile_file in profile_fit.runset.profile_files:
assert os.path.exists(profile_file)
diagnostics_file = profile_fit.runset.diagnostic_files[i]
assert os.path.exists(diagnostics_file)


def test_xarray_draws() -> None:
Expand Down

0 comments on commit 4f39687

Please sign in to comment.