Skip to content

Commit

Permalink
Update seed/rng-dependent tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Jan 30, 2024
1 parent 616cce9 commit 09a7d48
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 90 deletions.
6 changes: 3 additions & 3 deletions src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ TEST(McmcNutsBaseNuts, divergence_test) {
}

TEST(McmcNutsBaseNuts, transition) {
stan::rng_t base_rng(0);
stan::rng_t base_rng = stan::services::util::create_rng(0, 0);

int model_size = 1;
double init_momentum = 1.5;
Expand Down Expand Up @@ -362,7 +362,7 @@ TEST(McmcNutsBaseNuts, transition) {
EXPECT_EQ((2 << (sampler.get_max_depth() - 1)) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_EQ(21 * init_momentum, s.cont_params()(0));
EXPECT_EQ(-31 * init_momentum, s.cont_params()(0));
EXPECT_EQ(0, s.log_prob());
EXPECT_EQ(1, s.accept_stat());
EXPECT_EQ("", debug.str());
Expand All @@ -373,7 +373,7 @@ TEST(McmcNutsBaseNuts, transition) {
}

TEST(McmcNutsBaseNuts, transition_egde_momenta) {
stan::rng_t base_rng(0);
stan::rng_t base_rng = stan::services::util::create_rng(424243, 0);

int model_size = 1;
double init_momentum = 1.5;
Expand Down
16 changes: 8 additions & 8 deletions src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) {
}

TEST(McmcSoftAbsNuts, transition_test) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::softabs_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -338,15 +338,15 @@ TEST(McmcSoftAbsNuts, transition_test) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(4, sampler.depth_);
EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_);
EXPECT_EQ(5, sampler.depth_);
EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_FLOAT_EQ(1.9313564, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.86902142, s.cont_params()(1));
EXPECT_FLOAT_EQ(1.6008, s.cont_params()(2));
EXPECT_FLOAT_EQ(-3.5239484, s.log_prob());
EXPECT_FLOAT_EQ(0.99690288, s.accept_stat());
EXPECT_FLOAT_EQ(-1.7373296, s.cont_params()(0));
EXPECT_FLOAT_EQ(1.0898665, s.cont_params()(1));
EXPECT_FLOAT_EQ(-0.38303182, s.cont_params()(2));
EXPECT_FLOAT_EQ(-2.1764181, s.log_prob());
EXPECT_FLOAT_EQ(0.9993856, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
16 changes: 8 additions & 8 deletions src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ TEST(McmcUnitENuts, tree_boundary_test) {
}

TEST(McmcUnitENuts, transition_test) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::unit_e_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -338,15 +338,15 @@ TEST(McmcUnitENuts, transition_test) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(4, sampler.depth_);
EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_);
EXPECT_EQ(5, sampler.depth_);
EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_FLOAT_EQ(1.8718261, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.74208695, s.cont_params()(1));
EXPECT_FLOAT_EQ(1.5202962, s.cont_params()(2));
EXPECT_FLOAT_EQ(-3.1828632, s.log_prob());
EXPECT_FLOAT_EQ(0.99604273, s.accept_stat());
EXPECT_FLOAT_EQ(-1.7890506, s.cont_params()(0));
EXPECT_FLOAT_EQ(1.2320533, s.cont_params()(1));
EXPECT_FLOAT_EQ(-0.62397981, s.cont_params()(2));
EXPECT_FLOAT_EQ(-2.554004, s.log_prob());
EXPECT_FLOAT_EQ(0.99910343, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <gtest/gtest.h>

TEST(McmcStaticUniform, unit_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::unit_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -41,9 +41,9 @@ TEST(McmcStaticUniform, unit_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -52,7 +52,7 @@ TEST(McmcStaticUniform, unit_e_transition) {
}

TEST(McmcStaticUniform, diag_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::diag_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -78,9 +78,9 @@ TEST(McmcStaticUniform, diag_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -89,7 +89,7 @@ TEST(McmcStaticUniform, diag_e_transition) {
}

TEST(McmcStaticUniform, dense_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::dense_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -115,9 +115,9 @@ TEST(McmcStaticUniform, dense_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -126,7 +126,7 @@ TEST(McmcStaticUniform, dense_e_transition) {
}

TEST(McmcStaticUniform, softabs_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::softabs_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -152,9 +152,9 @@ TEST(McmcStaticUniform, softabs_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.37006485, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.068473995, s.log_prob());
EXPECT_FLOAT_EQ(0.9999119, s.accept_stat());
EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.176342, s.log_prob());
EXPECT_FLOAT_EQ(0.9996115, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -163,7 +163,7 @@ TEST(McmcStaticUniform, softabs_transition) {
}

TEST(McmcStaticUniform, adapt_unit_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::unit_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -189,9 +189,9 @@ TEST(McmcStaticUniform, adapt_unit_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -200,7 +200,7 @@ TEST(McmcStaticUniform, adapt_unit_e_transition) {
}

TEST(McmcStaticUniform, adapt_diag_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::diag_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -226,9 +226,9 @@ TEST(McmcStaticUniform, adapt_diag_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -237,7 +237,7 @@ TEST(McmcStaticUniform, adapt_diag_e_transition) {
}

TEST(McmcStaticUniform, adapt_dense_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::dense_e_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -263,9 +263,9 @@ TEST(McmcStaticUniform, adapt_dense_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.27224374, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.037058324, s.log_prob());
EXPECT_FLOAT_EQ(0.9998666, s.accept_stat());
EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand All @@ -274,7 +274,7 @@ TEST(McmcStaticUniform, adapt_dense_e_transition) {
}

TEST(McmcStaticUniform, adapt_softabs_e_transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::softabs_point z_init(1);
z_init.q(0) = 1;
Expand All @@ -300,9 +300,9 @@ TEST(McmcStaticUniform, adapt_softabs_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(0.37006485, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.068473995, s.log_prob());
EXPECT_FLOAT_EQ(0.9999119, s.accept_stat());
EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.176342, s.log_prob());
EXPECT_FLOAT_EQ(0.9996115, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
4 changes: 2 additions & 2 deletions src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ TEST(McmcXHMCBaseXHMC, divergence_test) {
}

TEST(McmcXHMCBaseXHMC, transition) {
stan::rng_t base_rng(0);
stan::rng_t base_rng = stan::services::util::create_rng(0, 0);

int model_size = 1;
double init_momentum = 1.5;
Expand All @@ -245,7 +245,7 @@ TEST(McmcXHMCBaseXHMC, transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(31.5, s.cont_params()(0));
EXPECT_EQ(-31 * init_momentum, s.cont_params()(0));
EXPECT_EQ(0, s.log_prob());
EXPECT_EQ(1, s.accept_stat());
EXPECT_EQ("", debug.str());
Expand Down
18 changes: 9 additions & 9 deletions src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <gtest/gtest.h>

TEST(McmcUnitEXHMC, build_tree) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(4839294, 0);

stan::mcmc::softabs_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) {
EXPECT_FLOAT_EQ(1.5019561, sampler.z().p(1));
EXPECT_FLOAT_EQ(-1.5019561, sampler.z().p(2));

EXPECT_FLOAT_EQ(0.8330583, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.8330583, z_propose.q(1));
EXPECT_FLOAT_EQ(0.8330583, z_propose.q(2));
EXPECT_FLOAT_EQ(0.42903179, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.42903179, z_propose.q(1));
EXPECT_FLOAT_EQ(0.42903179, z_propose.q(2));

EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(0));
EXPECT_FLOAT_EQ(1.1836562, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(2));
EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(0));
EXPECT_FLOAT_EQ(1.4385087, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(2));

EXPECT_EQ(8, n_leapfrog);
EXPECT_FLOAT_EQ(3.7645235, ave);
Expand All @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) {
}

TEST(McmcUnitEXHMC, transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(483294, 0);

stan::mcmc::softabs_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) {
EXPECT_FLOAT_EQ(-1, s.cont_params()(1));
EXPECT_FLOAT_EQ(1, s.cont_params()(2));
EXPECT_FLOAT_EQ(-1.5, s.log_prob());
EXPECT_FLOAT_EQ(0.99829924, s.accept_stat());
EXPECT_FLOAT_EQ(0.99993229, s.accept_stat());

EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
Expand Down
18 changes: 9 additions & 9 deletions src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <gtest/gtest.h>

TEST(McmcUnitEXHMC, build_tree) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(483294, 0);

stan::mcmc::unit_e_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) {
EXPECT_FLOAT_EQ(1.4131583, sampler.z().p(1));
EXPECT_FLOAT_EQ(-1.4131583, sampler.z().p(2));

EXPECT_FLOAT_EQ(0.78105003, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.78105003, z_propose.q(1));
EXPECT_FLOAT_EQ(0.78105003, z_propose.q(2));
EXPECT_FLOAT_EQ(0.65928948, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.65928948, z_propose.q(1));
EXPECT_FLOAT_EQ(0.65928948, z_propose.q(2));

EXPECT_FLOAT_EQ(-1.1785525, z_propose.p(0));
EXPECT_FLOAT_EQ(1.1785525, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.1785525, z_propose.p(2));
EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(0));
EXPECT_FLOAT_EQ(1.2505695, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(2));

EXPECT_EQ(8, n_leapfrog);
EXPECT_FLOAT_EQ(4.2207355, ave);
Expand All @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) {
}

TEST(McmcUnitEXHMC, transition) {
stan::rng_t base_rng(4839294);
stan::rng_t base_rng = stan::services::util::create_rng(483294, 0);

stan::mcmc::unit_e_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) {
EXPECT_FLOAT_EQ(-1, s.cont_params()(1));
EXPECT_FLOAT_EQ(1, s.cont_params()(2));
EXPECT_FLOAT_EQ(-1.5, s.log_prob());
EXPECT_FLOAT_EQ(0.99805242, s.accept_stat());
EXPECT_FLOAT_EQ(0.99994934, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
Loading

0 comments on commit 09a7d48

Please sign in to comment.