Skip to content

Commit

Permalink
Add new Pathfinder arguments
Browse files: browse the repository at this point in the history
  • Loading branch information
WardBrian committed Jan 19, 2024
1 parent 078c43a commit dc1b939
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 8 deletions.
10 changes: 10 additions & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,8 @@ def __init__(
num_draws: Optional[int] = None,
num_elbo_draws: Optional[int] = None,
save_single_paths: bool = False,
psis_resample: bool = True,
calculate_lp: bool = True,
) -> None:
self.init_alpha = init_alpha
self.tol_obj = tol_obj
Expand All @@ -557,6 +559,8 @@ def __init__(
self.num_elbo_draws = num_elbo_draws

self.save_single_paths = save_single_paths
self.psis_resample = psis_resample
self.calculate_lp = calculate_lp

def validate(self, _chains: Optional[int] = None) -> None:
"""
Expand Down Expand Up @@ -609,6 +613,12 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]:
if self.save_single_paths:
cmd.append('save_single_paths=1')

if not self.psis_resample:
cmd.append('psis_resample=0')

if not self.calculate_lp:
cmd.append('calculate_lp=0')

return cmd


Expand Down
24 changes: 23 additions & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,8 @@ def pathfinder(
draws: Optional[int] = None,
num_single_draws: Optional[int] = None,
num_elbo_draws: Optional[int] = None,
psis_resample: bool = True,
calculate_lp: bool = True,
# arguments standard to all methods
seed: Optional[int] = None,
inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
Expand Down Expand Up @@ -1645,6 +1647,14 @@ def pathfinder(
:param num_elbo_draws: Number of Monte Carlo draws to evaluate ELBO.
:param psis_resample: Whether or not to use Pareto Smoothed Importance
Sampling on the result of the individual Pathfinders. If False, the
result contains the draws from each path.
:param calculate_lp: Whether or not to calculate the log probability
for approximate draws. If False, this also implies that
``psis_resample`` will be set to False.
:param seed: The seed for random number generator. Must be an integer
between 0 and 2^32 - 1. If unspecified,
:func:`numpy.random.default_rng` is used to generate a seed.
Expand Down Expand Up @@ -1726,12 +1736,22 @@ def pathfinder(
Research, 23(306), 1–49. Retrieved from
http://jmlr.org/papers/v23/21-0889.html
"""
if cmdstan_version_before(2, 33, self.exe_info()):

exe_info = self.exe_info()
if cmdstan_version_before(2, 33, exe_info):
raise ValueError(
"Method 'pathfinder' not available for CmdStan versions "
"before 2.33"
)

if (not psis_resample or not calculate_lp) and cmdstan_version_before(
2, 34, exe_info
):
raise ValueError(
"Arguments 'psis_resample' and 'calculate_lp' are only "
"available for CmdStan versions 2.34 and later"
)

if num_paths == 1:
if num_single_draws is None:
num_single_draws = draws
Expand All @@ -1754,6 +1774,8 @@ def pathfinder(
max_lbfgs_iters=max_lbfgs_iters,
num_draws=num_single_draws,
num_elbo_draws=num_elbo_draws,
psis_resample=psis_resample,
calculate_lp=calculate_lp,
)

with temp_single_json(data) as _data, temp_inits(inits) as _inits:
Expand Down
12 changes: 12 additions & 0 deletions cmdstanpy/stanfit/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,18 @@ def column_names(self) -> Tuple[str, ...]:
"""
return self._metadata.cmdstan_config['column_names'] # type: ignore

@property
def is_resampled(self) -> bool:
    """
    Returns True if the draws were resampled from several Pathfinder
    approximations, False otherwise.

    Resampling is reported only when all three recorded settings allow
    it: more than one path, ``psis_resample`` equal to 1, and
    ``calculate_lp`` equal to 1. A setting that is missing from the
    recorded configuration falls back to the default used here
    (4 paths, both flags set to 1).
    """
    config = self._metadata.cmdstan_config
    # bool() pins the declared return type, so the values read out of the
    # config mapping no longer need a blanket ``type: ignore``.
    return bool(
        config.get('num_paths', 4) > 1
        and config.get('psis_resample', 1) == 1
        and config.get('calculate_lp', 1) == 1
    )

def save_csvfiles(self, dir: Optional[str] = None) -> None:
"""
Move output CSV files to specified directory. If files were
Expand Down
22 changes: 15 additions & 7 deletions test/test_log_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,21 @@
BERN_BASENAME = 'bernoulli'


@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [
(11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]),
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"])
])
def test_lp_good(sig_figs: Optional[int], expected: List[str],
expected_unadjusted: List[str]) -> None:
@pytest.mark.parametrize(
"sig_figs, expected, expected_unadjusted",
[
(
11,
["-7.0214667713", "-1.188472607"],
["-5.5395901199", "-1.4903938392"],
),
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]),
],
)
def test_lp_good(
sig_figs: Optional[int], expected: List[str], expected_unadjusted: List[str]
) -> None:
model = CmdStanModel(stan_file=BERN_STAN)
params = {"theta": 0.34903938392023830482}
out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs)
Expand Down
30 changes: 30 additions & 0 deletions test/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pathlib import Path

import numpy as np
import pytest

import cmdstanpy
Expand Down Expand Up @@ -31,6 +32,8 @@ def test_pathfinder_outputs():
assert theta.shape == (draws,)
assert 0.23 < theta.mean() < 0.27

assert pathfinder.is_resampled

assert pathfinder.draws().shape == (draws, 3)


Expand Down Expand Up @@ -58,6 +61,8 @@ def test_single_pathfinder():
draws=draws,
)

assert not pathfinder.is_resampled

theta = pathfinder.theta
assert theta.shape == (draws,)

Expand Down Expand Up @@ -122,3 +127,28 @@ def test_pathfinder_init_sampling():

assert fit.chains == 4
assert fit.draws().shape == (1000, 4, 9)


def test_pathfinder_no_psis():
    """Pathfinder run with ``psis_resample=False`` is not resampled."""
    model = cmdstanpy.CmdStanModel(stan_file=DATAFILES_PATH / 'bernoulli.stan')
    data_file = str(DATAFILES_PATH / 'bernoulli.data.json')

    fit = model.pathfinder(data=data_file, psis_resample=False)

    # Without resampling the output holds 4000 rows of 3 columns
    # (presumably the draws of every individual path - confirm against
    # the CmdStan defaults).
    assert not fit.is_resampled
    assert fit.draws().shape == (4000, 3)


def test_pathfinder_no_lp_calc():
    """Pathfinder run with ``calculate_lp=False`` leaves most lp__ as NaN."""
    model = cmdstanpy.CmdStanModel(stan_file=DATAFILES_PATH / 'bernoulli.stan')
    data_file = str(DATAFILES_PATH / 'bernoulli.data.json')

    fit = model.pathfinder(data=data_file, calculate_lp=False)

    assert not fit.is_resampled
    assert fit.draws().shape == (4000, 3)

    lp_values = fit.method_variables()['lp__']
    nan_count = np.isnan(lp_values).sum()
    # Some lp values are still calculated during pathfinder ...
    assert nan_count < 4000
    # ... but most of them are not.
    assert nan_count > 3000

0 comments on commit dc1b939

Please sign in to comment.