Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avehtari committed Jan 24, 2024
1 parent a54d577 commit 26e8a7b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
12 changes: 6 additions & 6 deletions tests/testthat/test_loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ test_that("loo_moment_match.default warnings work", {
expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test,
unconstrain_pars_test, log_prob_upars_test,
log_lik_i_upars_test, max_iters = 30L,
k_thres = 100, split = FALSE,
cov = TRUE, cores = 1), "Some Pareto k")
k_thres = 0.5, split = FALSE,
cov = TRUE, cores = 1), "The accuracy of self-normalized importance sampling")

expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test,
expect_no_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test,
unconstrain_pars_test, log_prob_upars_test,
log_lik_i_upars_test, max_iters = 30L,
k_thres = 0.5, split = FALSE,
cov = TRUE, cores = 1), "The accuracy of self-normalized importance sampling")
k_thres = 100, split = TRUE,
cov = TRUE, cores = 1))

expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test,
unconstrain_pars_test, log_prob_upars_test,
Expand Down Expand Up @@ -180,7 +180,7 @@ test_that("loo_moment_match.default works", {
k_thres = 0.8, split = FALSE,
cov = TRUE, cores = 1))

# diagnostic pareto k decreases but influence pareto k stays the same
# diagnostic Pareto k decreases but influence pareto k stays the same
expect_lt(loo_moment_match_object$diagnostics$pareto_k[1], loo_moment_match_object$pointwise[1,"influence_pareto_k"])
expect_equal(loo_moment_match_object$pointwise[,"influence_pareto_k"],loo_manual$pointwise[,"influence_pareto_k"])
expect_equal(loo_moment_match_object$pointwise[,"influence_pareto_k"],loo_manual$diagnostics$pareto_k)
Expand Down
30 changes: 13 additions & 17 deletions tests/testthat/test_print_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ test_that("pareto_k_ids identifies correct observations", {
})

test_that("pareto_k_table gives correct output", {
psis1$diagnostics$pareto_k[1:10] <- runif(10, 0, 0.49)
psis1$diagnostics$pareto_k[11:17] <- runif(7, 0.51, 0.69)
psis1$diagnostics$pareto_k[18:20] <- runif(3, 0.71, 0.99)
threshold <- ps_khat_threshold(dim(psis1)[1])
psis1$diagnostics$pareto_k[1:10] <- runif(10, 0, threshold)
psis1$diagnostics$pareto_k[11:20] <- runif(10, threshold+0.01, 0.99)
psis1$diagnostics$pareto_k[21:32] <- runif(12, 1, 10)
k <- pareto_k_values(psis1)
tab <- pareto_k_table(psis1)
Expand All @@ -114,24 +114,20 @@ test_that("pareto_k_table gives correct output", {
expect_equal(sum(tab[, "Count"]), length(k))
expect_equal(sum(tab[, "Proportion"]), 1)

expect_equal(sum(k <= 0.5), tab[1,1])
expect_equal(sum(k > 0.5 & k <= 0.7), tab[2,1])
expect_equal(sum(k > 0.7 & k <= 1), tab[3,1])
expect_equal(sum(k > 1), tab[4,1])

psis1$diagnostics$pareto_k[1:32] <- 0.4
expect_output(print(pareto_k_table(psis1)), "All Pareto k estimates are good (k < 0.5)",
fixed = TRUE)

psis1$diagnostics$pareto_k[1:32] <- 0.65
expect_output(print(pareto_k_table(psis1)), "All Pareto k estimates are ok (k < 0.7)",
fixed = TRUE)
expect_equal(sum(k <= threshold), tab[1,1])
expect_equal(sum(k > threshold & k <= 1), tab[2,1])
expect_equal(sum(k > 1), tab[3,1])

# if n_eff is NULL
psis1$diagnostics$n_eff <- NULL
tab2 <- pareto_k_table(psis1)
expect_output(print(tab2), "<NA>")
expect_equal(unname(tab2[, "Min. n_eff"]), rep(NA_real_, 4))
expect_equal(unname(tab2[, "Min. n_eff"]), rep(NA_real_, 3))

psis1$diagnostics$pareto_k[1:32] <- 0.4
expect_output(print(pareto_k_table(psis1)),
paste0("All Pareto k estimates are good (k < ", round(threshold,2), ")"),
fixed = TRUE)
})


Expand All @@ -144,7 +140,7 @@ test_that("psis_n_eff_values extractor works", {
expect_identical(psis_n_eff_values(psis1), psis_n_eff_values(loo1))

psis1$diagnostics$n_eff <- NULL
expect_error(psis_n_eff_values(psis1), "No PSIS n_eff estimates found")
expect_error(psis_n_eff_values(psis1), "No PSIS ESS estimates found")
})

test_that("mcse_loo extractor gives correct value", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_psis_approximate_posterior.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ test_that("Laplace approximation, normal model", {
log_p <- test_data_psis_approximate_posterior$laplace_normal$log_p
log_g <- test_data_psis_approximate_posterior$laplace_normal$log_q
ll <- test_data_psis_approximate_posterior$laplace_normal$log_liks
expect_warning(
expect_no_warning(
psis_lap <-
psis_approximate_posterior(log_p = log_p, log_g = log_g, cores = 1, save_psis = FALSE)
)
Expand Down

0 comments on commit 26e8a7b

Please sign in to comment.