diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp index 3f177b5a13e..7d99ae65e22 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp @@ -20,7 +20,7 @@ class dense_e_metric : public base_hamiltonian { : base_hamiltonian(model) {} double T(dense_e_point& z) { - return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p; + return 0.5 * z.p.transpose() * z.get_inv_metric() * z.p; } double tau(dense_e_point& z) { return T(z); } @@ -35,7 +35,7 @@ class dense_e_metric : public base_hamiltonian { return Eigen::VectorXd::Zero(this->model_.num_params_r()); } - Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.inv_e_metric_ * z.p; } + Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.get_inv_metric() * z.p; } Eigen::VectorXd dphi_dq(dense_e_point& z, callbacks::logger& logger) { return z.g; @@ -51,7 +51,7 @@ class dense_e_metric : public base_hamiltonian { for (idx_t i = 0; i < u.size(); ++i) u(i) = rand_dense_gaus(); - z.p = z.inv_e_metric_.llt().matrixU().solve(u); + z.p = z.get_transpose_llt_inv_metric().triangularView().solve(u); } }; diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp index 3b811e280e3..ca96eced6eb 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp @@ -11,11 +11,13 @@ namespace mcmc { * Euclidean manifold with dense metric */ class dense_e_point : public ps_point { - public: +private: /** * Inverse mass matrix. */ Eigen::MatrixXd inv_e_metric_; + Eigen::MatrixXd inv_e_metric_llt_matrixU_; +public: /** * Construct a dense point in n-dimensional phase space @@ -23,17 +25,40 @@ class dense_e_point : public ps_point { * * @param n number of dimensions */ - explicit dense_e_point(int n) : ps_point(n), inv_e_metric_(n, n) { + explicit dense_e_point(int n) : ps_point(n), + inv_e_metric_(n, n), + inv_e_metric_llt_matrixU_(n, n) { inv_e_metric_.setIdentity(); + inv_e_metric_llt_matrixU_.setIdentity(); } /** - * Set elements of mass matrix + * Set inverse metric * * @param inv_e_metric initial mass matrix */ - void set_metric(const Eigen::MatrixXd& inv_e_metric) { + void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) { inv_e_metric_ = inv_e_metric; + inv_e_metric_llt_matrixU_ = inv_e_metric_.llt().matrixU(); + } + + /** + * Get inverse metric + * + * @return reference to the inverse metric + */ + const Eigen::MatrixXd& get_inv_metric() { + return inv_e_metric_; + } + + /** + * Get the transpose of the lower Cholesky factor + * of the inverse metric + * + * @return reference to transpose of Cholesky factor + */ + const Eigen::MatrixXd& get_transpose_llt_inv_metric() { + return inv_e_metric_llt_matrixU_; } /** diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp index 98bfee84294..2762899f765 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp @@ -18,7 +18,7 @@ class diag_e_metric : public base_hamiltonian { : base_hamiltonian(model) {} double T(diag_e_point& z) { - return 0.5 * z.p.dot(z.inv_e_metric_.cwiseProduct(z.p)); + return 0.5 * z.p.dot(z.get_inv_metric().cwiseProduct(z.p)); } double tau(diag_e_point& z) { return T(z); } @@ -34,7 +34,7 @@ class diag_e_metric : public base_hamiltonian { } Eigen::VectorXd dtau_dp(diag_e_point& z) { - return z.inv_e_metric_.cwiseProduct(z.p); + return z.get_inv_metric().cwiseProduct(z.p); } Eigen::VectorXd dphi_dq(diag_e_point& z, callbacks::logger& logger) { @@ -46,7 +46,7 @@ class diag_e_metric : public base_hamiltonian { rand_diag_gaus(rng, boost::normal_distribution<>()); for (int i = 0; i < z.p.size(); ++i) - z.p(i) = rand_diag_gaus() / sqrt(z.inv_e_metric_(i)); + z.p(i) = rand_diag_gaus() / sqrt(z.get_inv_metric()(i)); } }; diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp index eb2c2d9f5d5..af136812afd 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp @@ -11,12 +11,13 @@ namespace mcmc { * Euclidean manifold with diagonal metric */ class diag_e_point : public ps_point { - public: +private: /** * Vector of diagonal elements of inverse mass matrix. */ Eigen::VectorXd inv_e_metric_; +public: /** * Construct a diag point in n-dimensional phase space * with vector of ones for diagonal elements of inverse mass matrix. @@ -32,10 +33,19 @@ class diag_e_point : public ps_point { * * @param inv_e_metric initial mass matrix */ - void set_metric(const Eigen::VectorXd& inv_e_metric) { + void set_inv_metric(const Eigen::VectorXd& inv_e_metric) { inv_e_metric_ = inv_e_metric; } + /** + * Get inverse metric + * + * @return reference to the inverse metric + */ + const Eigen::VectorXd& get_inv_metric() { + return inv_e_metric_; + } + /** * Write elements of mass matrix to string and handoff to writer. * diff --git a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp index 05b6c80523f..e78f9440cd6 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp @@ -29,12 +29,15 @@ class adapt_dense_e_nuts : public dense_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + Eigen::MatrixXd inv_metric; + + bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q); if (update) { this->init_stepsize(logger); + this->z_.set_inv_metric(inv_metric); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); } diff --git a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp index 45e92380f57..94a05abc70a 100644 --- a/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp @@ -29,12 +29,15 @@ class adapt_diag_e_nuts : public diag_e_nuts, this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_, s.accept_stat()); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, - this->z_.q); - + Eigen::VectorXd inv_metric; + + bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q); + if (update) { this->init_stepsize(logger); + this->z_.set_inv_metric(inv_metric); + this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); } diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 7b08ff8ce5d..a495a06754a 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -58,12 +58,12 @@ class base_nuts : public base_hmc { ~base_nuts() {} - void set_metric(const Eigen::MatrixXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); + void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_inv_metric(inv_e_metric); } - void set_metric(const Eigen::VectorXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); + void set_inv_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_inv_metric(inv_e_metric); } void set_max_depth(int d) { diff --git a/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp b/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp index 7691fc6f6e2..957c96ee390 100644 --- a/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp +++ b/src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp @@ -32,12 +32,14 @@ class adapt_dense_e_static_hmc : public dense_e_static_hmc, s.accept_stat()); this->update_L_(); - bool update = this->covar_adaptation_.learn_covariance( - this->z_.inv_e_metric_, this->z_.q); + Eigen::MatrixXd inv_metric; + + bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q); if (update) { this->init_stepsize(logger); this->update_L_(); + this->z_.set_inv_metric(inv_metric); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); diff --git a/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp b/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp index a2098ef95f0..8f443b01cc3 100644 --- a/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp +++ b/src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp @@ -32,12 +32,15 @@ class adapt_diag_e_static_hmc : public diag_e_static_hmc, s.accept_stat()); this->update_L_(); - bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_, + Eigen::VectorXd inv_metric; + + bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q); if (update) { this->init_stepsize(logger); this->update_L_(); + this->z_.set_inv_metric(inv_metric); this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_)); this->stepsize_adaptation_.restart(); diff --git a/src/stan/mcmc/hmc/static/base_static_hmc.hpp b/src/stan/mcmc/hmc/static/base_static_hmc.hpp index 4dab52ad9b6..7d30447138f 100644 --- a/src/stan/mcmc/hmc/static/base_static_hmc.hpp +++ b/src/stan/mcmc/hmc/static/base_static_hmc.hpp @@ -30,12 +30,12 @@ class base_static_hmc ~base_static_hmc() {} - void set_metric(const Eigen::MatrixXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); + void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) { + this->z_.set_inv_metric(inv_e_metric); } - void set_metric(const Eigen::VectorXd& inv_e_metric) { - this->z_.set_metric(inv_e_metric); + void set_inv_metric(const Eigen::VectorXd& inv_e_metric) { + this->z_.set_inv_metric(inv_e_metric); } sample transition(sample& init_sample, callbacks::logger& logger) { diff --git a/src/stan/services/sample/hmc_nuts_dense_e.hpp b/src/stan/services/sample/hmc_nuts_dense_e.hpp index f57466503db..671763279b7 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e.hpp @@ -73,7 +73,7 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init, stan::mcmc::dense_e_nuts sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize(stepsize); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index cb644a346c6..2c071d9199f 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -81,7 +81,7 @@ int hmc_nuts_dense_e_adapt( stan::mcmc::adapt_dense_e_nuts sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize(stepsize); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/stan/services/sample/hmc_nuts_diag_e.hpp b/src/stan/services/sample/hmc_nuts_diag_e.hpp index 3932a870dba..dc1f19724c0 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e.hpp @@ -73,7 +73,7 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init, stan::mcmc::diag_e_nuts sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize(stepsize); sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 8d03373e1a5..87211e6f2b2 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -81,7 +81,7 @@ int hmc_nuts_diag_e_adapt( stan::mcmc::adapt_diag_e_nuts sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize(stepsize); sampler.set_stepsize_jitter(stepsize_jitter); sampler.set_max_depth(max_depth); diff --git a/src/stan/services/sample/hmc_static_dense_e.hpp b/src/stan/services/sample/hmc_static_dense_e.hpp index 5ddb3d224eb..eead72689d9 100644 --- a/src/stan/services/sample/hmc_static_dense_e.hpp +++ b/src/stan/services/sample/hmc_static_dense_e.hpp @@ -70,7 +70,7 @@ int hmc_static_dense_e( stan::mcmc::dense_e_static_hmc sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp index 979e5ded1b8..6eec480bc6d 100644 --- a/src/stan/services/sample/hmc_static_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_dense_e_adapt.hpp @@ -82,7 +82,7 @@ int hmc_static_dense_e_adapt( stan::mcmc::adapt_dense_e_static_hmc sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/stan/services/sample/hmc_static_diag_e.hpp b/src/stan/services/sample/hmc_static_diag_e.hpp index 57109395f5d..b1aa63465e7 100644 --- a/src/stan/services/sample/hmc_static_diag_e.hpp +++ b/src/stan/services/sample/hmc_static_diag_e.hpp @@ -75,7 +75,7 @@ int hmc_static_diag_e(Model& model, const stan::io::var_context& init, stan::mcmc::diag_e_static_hmc sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp index f374f205c62..86a4896988c 100644 --- a/src/stan/services/sample/hmc_static_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_static_diag_e_adapt.hpp @@ -82,7 +82,7 @@ int hmc_static_diag_e_adapt( stan::mcmc::adapt_diag_e_static_hmc sampler(model, rng); - sampler.set_metric(inv_metric); + sampler.set_inv_metric(inv_metric); sampler.set_nominal_stepsize_and_T(stepsize, int_time); sampler.set_stepsize_jitter(stepsize_jitter); diff --git a/src/test/unit/mcmc/hmc/hamiltonians/dense_e_metric_test.cpp b/src/test/unit/mcmc/hmc/hamiltonians/dense_e_metric_test.cpp index 2340609aebe..50dd3227f5a 100644 --- a/src/test/unit/mcmc/hmc/hamiltonians/dense_e_metric_test.cpp +++ b/src/test/unit/mcmc/hmc/hamiltonians/dense_e_metric_test.cpp @@ -25,7 +25,7 @@ TEST(McmcDenseEMetric, sample_p) { stan::mcmc::dense_e_metric metric(model); stan::mcmc::dense_e_point z(2); - z.set_metric(m_inv); + z.set_inv_metric(m_inv); int n_samples = 1000;