diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 2e5d045e..984d6966 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -541,6 +541,8 @@ def __init__( num_draws: Optional[int] = None, num_elbo_draws: Optional[int] = None, save_single_paths: bool = False, + psis_resample: bool = True, + calculate_lp: bool = True, ) -> None: self.init_alpha = init_alpha self.tol_obj = tol_obj @@ -557,6 +559,8 @@ def __init__( self.num_elbo_draws = num_elbo_draws self.save_single_paths = save_single_paths + self.psis_resample = psis_resample + self.calculate_lp = calculate_lp def validate(self, _chains: Optional[int] = None) -> None: """ @@ -609,6 +613,12 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]: if self.save_single_paths: cmd.append('save_single_paths=1') + if not self.psis_resample: + cmd.append('psis_resample=0') + + if not self.calculate_lp: + cmd.append('calculate_lp=0') + return cmd diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 05a9a4b2..3ec13b90 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -1612,6 +1612,8 @@ def pathfinder( draws: Optional[int] = None, num_single_draws: Optional[int] = None, num_elbo_draws: Optional[int] = None, + psis_resample: bool = True, + calculate_lp: bool = True, # arguments standard to all methods seed: Optional[int] = None, inits: Union[Dict[str, float], float, str, os.PathLike, None] = None, @@ -1645,6 +1647,14 @@ def pathfinder( :param num_elbo_draws: Number of Monte Carlo draws to evaluate ELBO. + :param psis_resample: Whether or not to use Pareto Smoothed Importance + Sampling on the result of the individual Pathfinders. If False, the + result contains the draws from each path. + + :param calculate_lp: Whether or not to calculate the log probability + for approximate draws. If False, this also implies that + ``psis_resample`` will be set to False. + :param seed: The seed for random number generator. Must be an integer between 0 and 2^32 - 1. If unspecified, :func:`numpy.random.default_rng` is used to generate a seed. @@ -1726,12 +1736,22 @@ def pathfinder( Research, 23(306), 1–49. Retrieved from http://jmlr.org/papers/v23/21-0889.html """ - if cmdstan_version_before(2, 33, self.exe_info()): + + exe_info = self.exe_info() + if cmdstan_version_before(2, 33, exe_info): raise ValueError( "Method 'pathfinder' not available for CmdStan versions " "before 2.33" ) + if (not psis_resample or not calculate_lp) and cmdstan_version_before( + 2, 34, exe_info + ): + raise ValueError( + "Arguments 'psis_resample' and 'calculate_lp' are only " + "available for CmdStan versions 2.34 and later" + ) + if num_paths == 1: if num_single_draws is None: num_single_draws = draws @@ -1754,6 +1774,8 @@ def pathfinder( max_lbfgs_iters=max_lbfgs_iters, num_draws=num_single_draws, num_elbo_draws=num_elbo_draws, + psis_resample=psis_resample, + calculate_lp=calculate_lp, ) with temp_single_json(data) as _data, temp_inits(inits) as _inits: diff --git a/cmdstanpy/stanfit/pathfinder.py b/cmdstanpy/stanfit/pathfinder.py index 8c7d867c..7ec0b7a2 100644 --- a/cmdstanpy/stanfit/pathfinder.py +++ b/cmdstanpy/stanfit/pathfinder.py @@ -206,6 +206,18 @@ def column_names(self) -> Tuple[str, ...]: """ return self._metadata.cmdstan_config['column_names'] # type: ignore + @property + def is_resampled(self) -> bool: + """ + Returns True if the draws were resampled from several Pathfinder + approximations, False otherwise. + """ + return ( # type: ignore + self._metadata.cmdstan_config.get("num_paths", 4) > 1 + and self._metadata.cmdstan_config.get('psis_resample', 1) == 1 + and self._metadata.cmdstan_config.get('calculate_lp', 1) == 1 + ) + def save_csvfiles(self, dir: Optional[str] = None) -> None: """ Move output CSV files to specified directory. If files were diff --git a/test/test_log_prob.py b/test/test_log_prob.py index 3e021296..bed2f57b 100644 --- a/test/test_log_prob.py +++ b/test/test_log_prob.py @@ -21,13 +21,21 @@ BERN_BASENAME = 'bernoulli' -@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [ - (11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]), - (3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]), - (None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]) -]) -def test_lp_good(sig_figs: Optional[int], expected: List[str], - expected_unadjusted: List[str]) -> None: +@pytest.mark.parametrize( + "sig_figs, expected, expected_unadjusted", + [ + ( + 11, + ["-7.0214667713", "-1.188472607"], + ["-5.5395901199", "-1.4903938392"], + ), + (3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]), + (None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]), + ], +) +def test_lp_good( + sig_figs: Optional[int], expected: List[str], expected_unadjusted: List[str] +) -> None: model = CmdStanModel(stan_file=BERN_STAN) params = {"theta": 0.34903938392023830482} out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs) diff --git a/test/test_pathfinder.py b/test/test_pathfinder.py index b359c0c0..b8f7c050 100644 --- a/test/test_pathfinder.py +++ b/test/test_pathfinder.py @@ -4,6 +4,7 @@ from pathlib import Path +import numpy as np import pytest import cmdstanpy @@ -31,6 +32,8 @@ def test_pathfinder_outputs(): assert theta.shape == (draws,) assert 0.23 < theta.mean() < 0.27 + assert pathfinder.is_resampled + assert pathfinder.draws().shape == (draws, 3) @@ -58,6 +61,8 @@ def test_single_pathfinder(): draws=draws, ) + assert not pathfinder.is_resampled + theta = pathfinder.theta assert theta.shape == (draws,) @@ -122,3 +127,28 @@ def test_pathfinder_init_sampling(): assert fit.chains == 4 assert fit.draws().shape == (1000, 4, 9) + + +def test_pathfinder_no_psis(): + stan = DATAFILES_PATH / 'bernoulli.stan' + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = str(DATAFILES_PATH / 'bernoulli.data.json') + + pathfinder = bern_model.pathfinder(data=jdata, psis_resample=False) + + assert not pathfinder.is_resampled + assert pathfinder.draws().shape == (4000, 3) + + +def test_pathfinder_no_lp_calc(): + stan = DATAFILES_PATH / 'bernoulli.stan' + bern_model = cmdstanpy.CmdStanModel(stan_file=stan) + jdata = str(DATAFILES_PATH / 'bernoulli.data.json') + + pathfinder = bern_model.pathfinder(data=jdata, calculate_lp=False) + + assert not pathfinder.is_resampled + assert pathfinder.draws().shape == (4000, 3) + n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__'])) + assert n_lp_nan < 4000 # some lp still calculated during pathfinder + assert n_lp_nan > 3000 # but most are not