Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a num_threads helper argument to pathfinder() #741

Merged
merged 4 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ def validate(self) -> None:
if not (
isinstance(self.method_args, SamplerArgs)
and self.method_args.num_chains > 1
or isinstance(self.method_args, PathfinderArgs)
):
if not os.path.exists(self.inits):
raise ValueError('no such file {}'.format(self.inits))
Expand Down
16 changes: 16 additions & 0 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,7 @@ def pathfinder(
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
num_threads: Optional[int] = None,
) -> CmdStanPathfinder:
"""
Run CmdStan's Pathfinder variational inference algorithm.
Expand Down Expand Up @@ -1737,6 +1738,10 @@ def pathfinder(
:param timeout: Duration at which Pathfinder times
out in seconds. Defaults to None.

:param num_threads: Number of threads to request for parallel execution.
A number other than ``1`` requires the model to have been compiled
with STAN_THREADS=True.

:return: A :class:`CmdStanPathfinder` object

References
Expand All @@ -1763,6 +1768,17 @@ def pathfinder(
"available for CmdStan versions 2.34 and later"
)

if num_threads is not None:
if (
num_threads != 1
and exe_info.get('STAN_THREADS', '').lower() != 'true'
):
raise ValueError(
"Model must be compiled with 'STAN_THREADS=true' to use"
" 'num_threads' argument"
)
os.environ['STAN_NUM_THREADS'] = str(num_threads)

if num_paths == 1:
if num_single_draws is None:
num_single_draws = draws
Expand Down
39 changes: 39 additions & 0 deletions test/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Tests for the Pathfinder method.
"""

import contextlib
from io import StringIO
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -129,6 +131,26 @@ def test_pathfinder_init_sampling():
assert fit.draws().shape == (1000, 4, 9)


def test_inits_for_pathfinder():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
bern_model.pathfinder(
jdata, inits=[{"theta": 0.1}, {"theta": 0.9}], num_paths=2
)

# second path is initialized too large!
with contextlib.redirect_stdout(StringIO()) as captured:
bern_model.pathfinder(
jdata,
inits=[{"theta": 0.1}, {"theta": 1.1}],
num_paths=2,
show_console=True,
)

assert "Bounded variable is 1.1" in captured.getvalue()


def test_pathfinder_no_psis():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
Expand All @@ -152,3 +174,20 @@ def test_pathfinder_no_lp_calc():
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
assert n_lp_nan > 3000 # but most are not


def test_pathfinder_threads():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')

bern_model.pathfinder(data=jdata, num_threads=1)

with pytest.raises(ValueError, match="STAN_THREADS"):
bern_model.pathfinder(data=jdata, num_threads=4)

bern_model = cmdstanpy.CmdStanModel(
stan_file=stan, cpp_options={'STAN_THREADS': True}, force_compile=True
)
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
assert pathfinder.draws().shape == (1000, 3)
4 changes: 3 additions & 1 deletion test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)
def test_bernoulli_good(stanfile: str):
stan = os.path.join(DATAFILES_PATH, stanfile)
bern_model = CmdStanModel(stan_file=stan)
bern_model = CmdStanModel(stan_file=stan, force_compile=True)

jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
bern_fit = bern_model.sample(
Expand All @@ -74,6 +74,8 @@ def test_bernoulli_good(stanfile: str):

for i in range(bern_fit.runset.chains):
csv_file = bern_fit.runset.csv_files[i]
# NB: This will fail if STAN_THREADS is enabled
# due to sampling only producing 1 stdout file in that case
stdout_file = bern_fit.runset.stdout_files[i]
assert os.path.exists(csv_file)
assert os.path.exists(stdout_file)
Expand Down