From 26e8a7bd5fcd5c7deb0f26fb0f5c05aeef870e1c Mon Sep 17 00:00:00 2001 From: Aki Vehtari Date: Wed, 24 Jan 2024 11:11:15 +0200 Subject: [PATCH] fix some tests --- tests/testthat/test_loo_moment_matching.R | 12 ++++---- tests/testthat/test_print_plot.R | 30 ++++++++----------- .../test_psis_approximate_posterior.R | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/tests/testthat/test_loo_moment_matching.R b/tests/testthat/test_loo_moment_matching.R index e981a689..10e80139 100644 --- a/tests/testthat/test_loo_moment_matching.R +++ b/tests/testthat/test_loo_moment_matching.R @@ -144,14 +144,14 @@ test_that("loo_moment_match.default warnings work", { expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test, unconstrain_pars_test, log_prob_upars_test, log_lik_i_upars_test, max_iters = 30L, - k_thres = 100, split = FALSE, - cov = TRUE, cores = 1), "Some Pareto k") + k_thres = 0.5, split = FALSE, + cov = TRUE, cores = 1), "The accuracy of self-normalized importance sampling") - expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test, + expect_no_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test, unconstrain_pars_test, log_prob_upars_test, log_lik_i_upars_test, max_iters = 30L, - k_thres = 0.5, split = FALSE, - cov = TRUE, cores = 1), "The accuracy of self-normalized importance sampling") + k_thres = 100, split = TRUE, + cov = TRUE, cores = 1)) expect_warning(loo_moment_match(x, loo_manual, post_draws_test, log_lik_i_test, unconstrain_pars_test, log_prob_upars_test, @@ -180,7 +180,7 @@ test_that("loo_moment_match.default works", { k_thres = 0.8, split = FALSE, cov = TRUE, cores = 1)) - # diagnostic pareto k decreases but influence pareto k stays the same + # diagnostic Pareto k decreases but influence pareto k stays the same expect_lt(loo_moment_match_object$diagnostics$pareto_k[1], loo_moment_match_object$pointwise[1,"influence_pareto_k"]) expect_equal(loo_moment_match_object$pointwise[,"influence_pareto_k"],loo_manual$pointwise[,"influence_pareto_k"]) expect_equal(loo_moment_match_object$pointwise[,"influence_pareto_k"],loo_manual$diagnostics$pareto_k) diff --git a/tests/testthat/test_print_plot.R b/tests/testthat/test_print_plot.R index 9d355901..a657f8c5 100644 --- a/tests/testthat/test_print_plot.R +++ b/tests/testthat/test_print_plot.R @@ -102,9 +102,9 @@ test_that("pareto_k_ids identifies correct observations", { }) test_that("pareto_k_table gives correct output", { - psis1$diagnostics$pareto_k[1:10] <- runif(10, 0, 0.49) - psis1$diagnostics$pareto_k[11:17] <- runif(7, 0.51, 0.69) - psis1$diagnostics$pareto_k[18:20] <- runif(3, 0.71, 0.99) + threshold <- ps_khat_threshold(dim(psis1)[1]) + psis1$diagnostics$pareto_k[1:10] <- runif(10, 0, threshold) + psis1$diagnostics$pareto_k[11:20] <- runif(10, threshold+0.01, 0.99) psis1$diagnostics$pareto_k[21:32] <- runif(12, 1, 10) k <- pareto_k_values(psis1) tab <- pareto_k_table(psis1) @@ -114,24 +114,20 @@ test_that("pareto_k_table gives correct output", { expect_equal(sum(tab[, "Count"]), length(k)) expect_equal(sum(tab[, "Proportion"]), 1) - expect_equal(sum(k <= 0.5), tab[1,1]) - expect_equal(sum(k > 0.5 & k <= 0.7), tab[2,1]) - expect_equal(sum(k > 0.7 & k <= 1), tab[3,1]) - expect_equal(sum(k > 1), tab[4,1]) - - psis1$diagnostics$pareto_k[1:32] <- 0.4 - expect_output(print(pareto_k_table(psis1)), "All Pareto k estimates are good (k < 0.5)", - fixed = TRUE) - - psis1$diagnostics$pareto_k[1:32] <- 0.65 - expect_output(print(pareto_k_table(psis1)), "All Pareto k estimates are ok (k < 0.7)", - fixed = TRUE) + expect_equal(sum(k <= threshold), tab[1,1]) + expect_equal(sum(k > threshold & k <= 1), tab[2,1]) + expect_equal(sum(k > 1), tab[3,1]) # if n_eff is NULL psis1$diagnostics$n_eff <- NULL tab2 <- pareto_k_table(psis1) expect_output(print(tab2), "") - expect_equal(unname(tab2[, "Min. n_eff"]), rep(NA_real_, 4)) + expect_equal(unname(tab2[, "Min. n_eff"]), rep(NA_real_, 3)) + + psis1$diagnostics$pareto_k[1:32] <- 0.4 + expect_output(print(pareto_k_table(psis1)), + paste0("All Pareto k estimates are good (k < ", round(threshold,2), ")"), + fixed = TRUE) }) @@ -144,7 +140,7 @@ test_that("psis_n_eff_values extractor works", { expect_identical(psis_n_eff_values(psis1), psis_n_eff_values(loo1)) psis1$diagnostics$n_eff <- NULL - expect_error(psis_n_eff_values(psis1), "No PSIS n_eff estimates found") + expect_error(psis_n_eff_values(psis1), "No PSIS ESS estimates found") }) test_that("mcse_loo extractor gives correct value", { diff --git a/tests/testthat/test_psis_approximate_posterior.R b/tests/testthat/test_psis_approximate_posterior.R index 83afd240..0529d0f4 100644 --- a/tests/testthat/test_psis_approximate_posterior.R +++ b/tests/testthat/test_psis_approximate_posterior.R @@ -47,7 +47,7 @@ test_that("Laplace approximation, normal model", { log_p <- test_data_psis_approximate_posterior$laplace_normal$log_p log_g <- test_data_psis_approximate_posterior$laplace_normal$log_q ll <- test_data_psis_approximate_posterior$laplace_normal$log_liks - expect_warning( + expect_no_warning( psis_lap <- psis_approximate_posterior(log_p = log_p, log_g = log_g, cores = 1, save_psis = FALSE) )