Skip to content

Commit

Permalink
Made it so that the decomposition of the inverse dense metric is done…
Browse files Browse the repository at this point in the history
… once each time the inverse metric is set (instead of every sample, Issue #2881).

This involved switching to setter/getters for interfacing with dense_e_point so I made the change for diag_e_point as well.

Also changed set_metric verbage to set_inv_metric.
  • Loading branch information
bbbales2 committed Mar 6, 2020
1 parent aaedb1f commit 45fff72
Show file tree
Hide file tree
Showing 19 changed files with 83 additions and 37 deletions.
6 changes: 3 additions & 3 deletions src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
: base_hamiltonian<Model, dense_e_point, BaseRNG>(model) {}

double T(dense_e_point& z) {
return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p;
return 0.5 * z.p.transpose() * z.get_inv_metric() * z.p;
}

double tau(dense_e_point& z) { return T(z); }
Expand All @@ -35,7 +35,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
return Eigen::VectorXd::Zero(this->model_.num_params_r());
}

Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.inv_e_metric_ * z.p; }
Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.get_inv_metric() * z.p; }

Eigen::VectorXd dphi_dq(dense_e_point& z, callbacks::logger& logger) {
return z.g;
Expand All @@ -51,7 +51,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
for (idx_t i = 0; i < u.size(); ++i)
u(i) = rand_dense_gaus();

z.p = z.inv_e_metric_.llt().matrixU().solve(u);
z.p = z.get_transpose_llt_inv_metric().triangularView<Eigen::Upper>().solve(u);
}
};

Expand Down
33 changes: 29 additions & 4 deletions src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,54 @@ namespace mcmc {
* Euclidean manifold with dense metric
*/
class dense_e_point : public ps_point {
public:
private:
/**
* Inverse mass matrix.
*/
Eigen::MatrixXd inv_e_metric_;
Eigen::MatrixXd inv_e_metric_llt_matrixU_;
public:

/**
* Construct a dense point in n-dimensional phase space
* with identity matrix as inverse mass matrix.
*
* @param n number of dimensions
*/
explicit dense_e_point(int n) : ps_point(n), inv_e_metric_(n, n) {
explicit dense_e_point(int n) : ps_point(n),
inv_e_metric_(n, n),
inv_e_metric_llt_matrixU_(n, n) {
inv_e_metric_.setIdentity();
inv_e_metric_llt_matrixU_.setIdentity();
}

/**
* Set elements of mass matrix
* Set inverse metric
*
* @param inv_e_metric initial mass matrix
*/
void set_metric(const Eigen::MatrixXd& inv_e_metric) {
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
inv_e_metric_ = inv_e_metric;
inv_e_metric_llt_matrixU_ = inv_e_metric_.llt().matrixU();
}

/**
* Get inverse metric
*
* @return reference to the inverse metric
*/
const Eigen::MatrixXd& get_inv_metric() {
return inv_e_metric_;
}

/**
* Get the transpose of the lower Cholesky factor
* of the inverse metric
*
* @return reference to transpose of Cholesky factor
*/
const Eigen::MatrixXd& get_transpose_llt_inv_metric() {
return inv_e_metric_llt_matrixU_;
}

/**
Expand Down
6 changes: 3 additions & 3 deletions src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
: base_hamiltonian<Model, diag_e_point, BaseRNG>(model) {}

double T(diag_e_point& z) {
return 0.5 * z.p.dot(z.inv_e_metric_.cwiseProduct(z.p));
return 0.5 * z.p.dot(z.get_inv_metric().cwiseProduct(z.p));
}

double tau(diag_e_point& z) { return T(z); }
Expand All @@ -34,7 +34,7 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
}

Eigen::VectorXd dtau_dp(diag_e_point& z) {
return z.inv_e_metric_.cwiseProduct(z.p);
return z.get_inv_metric().cwiseProduct(z.p);
}

Eigen::VectorXd dphi_dq(diag_e_point& z, callbacks::logger& logger) {
Expand All @@ -46,7 +46,7 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
rand_diag_gaus(rng, boost::normal_distribution<>());

for (int i = 0; i < z.p.size(); ++i)
z.p(i) = rand_diag_gaus() / sqrt(z.inv_e_metric_(i));
z.p(i) = rand_diag_gaus() / sqrt(z.get_inv_metric()(i));
}
};

Expand Down
14 changes: 12 additions & 2 deletions src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ namespace mcmc {
* Euclidean manifold with diagonal metric
*/
class diag_e_point : public ps_point {
public:
private:
/**
* Vector of diagonal elements of inverse mass matrix.
*/
Eigen::VectorXd inv_e_metric_;

public:
/**
* Construct a diag point in n-dimensional phase space
* with vector of ones for diagonal elements of inverse mass matrix.
Expand All @@ -32,10 +33,19 @@ class diag_e_point : public ps_point {
*
* @param inv_e_metric initial mass matrix
*/
void set_metric(const Eigen::VectorXd& inv_e_metric) {
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
inv_e_metric_ = inv_e_metric;
}

/**
* Get inverse metric
*
* @return reference to the inverse metric
*/
const Eigen::VectorXd& get_inv_metric() {
return inv_e_metric_;
}

/**
* Write elements of mass matrix to string and handoff to writer.
*
Expand Down
7 changes: 5 additions & 2 deletions src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_dense_e_nuts : public dense_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
9 changes: 6 additions & 3 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);

Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
8 changes: 4 additions & 4 deletions src/stan/mcmc/hmc/nuts/base_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {

~base_nuts() {}

void set_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_max_depth(int d) {
Expand Down
6 changes: 4 additions & 2 deletions src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ class adapt_dense_e_static_hmc : public dense_e_static_hmc<Model, BaseRNG>,
s.accept_stat());
this->update_L_();

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);
this->update_L_();
this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
Expand Down
5 changes: 4 additions & 1 deletion src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ class adapt_diag_e_static_hmc : public diag_e_static_hmc<Model, BaseRNG>,
s.accept_stat());
this->update_L_();

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric,
this->z_.q);

if (update) {
this->init_stepsize(logger);
this->update_L_();
this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
Expand Down
8 changes: 4 additions & 4 deletions src/stan/mcmc/hmc/static/base_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class base_static_hmc

~base_static_hmc() {}

void set_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

sample transition(sample& init_sample, callbacks::logger& logger) {
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,

stan::mcmc::dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);

sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int hmc_nuts_dense_e_adapt(

stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);

sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_diag_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init,

stan::mcmc::diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int hmc_nuts_diag_e_adapt(

stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int hmc_static_dense_e(

stan::mcmc::dense_e_static_hmc<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int hmc_static_dense_e_adapt(
stan::mcmc::adapt_dense_e_static_hmc<Model, boost::ecuyer1988> sampler(model,
rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_diag_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int hmc_static_diag_e(Model& model, const stan::io::var_context& init,

stan::mcmc::diag_e_static_hmc<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int hmc_static_diag_e_adapt(
stan::mcmc::adapt_diag_e_static_hmc<Model, boost::ecuyer1988> sampler(model,
rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ TEST(McmcDenseEMetric, sample_p) {

stan::mcmc::dense_e_metric<stan::mcmc::mock_model, rng_t> metric(model);
stan::mcmc::dense_e_point z(2);
z.set_metric(m_inv);
z.set_inv_metric(m_inv);

int n_samples = 1000;

Expand Down

0 comments on commit 45fff72

Please sign in to comment.