From 37205ffecd60d8bf17a5ce36404de288140c1172 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 30 Jul 2024 16:06:40 -0600 Subject: [PATCH] Fix missing legends for unobserved levels in rhat and neff plots closes #327 --- R/mcmc-diagnostics.R | 9 +- .../mcmc-neff-missing-levels.svg | 93 +++++++++++++++++++ tests/testthat/test-mcmc-diagnostics.R | 11 +++ 3 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/_snaps/mcmc-diagnostics/mcmc-neff-missing-levels.svg diff --git a/R/mcmc-diagnostics.R b/R/mcmc-diagnostics.R index a56be0bd..3dc1b324 100644 --- a/R/mcmc-diagnostics.R +++ b/R/mcmc-diagnostics.R @@ -142,7 +142,8 @@ mcmc_rhat <- function(rhat, ..., size = NULL) { mapping = aes( yend = .data$parameter, xend = ifelse(min(.data$value) < 1, 1, -Inf)), - na.rm = TRUE) + + na.rm = TRUE, + show.legend = TRUE) + bayesplot_theme_get() if (min(data$value) < 1) { @@ -238,7 +239,8 @@ mcmc_neff <- function(ratio, ..., size = NULL) { fill = .data$rating)) + geom_segment( aes(yend = .data$parameter, xend = -Inf), - na.rm = TRUE) + + na.rm = TRUE, + show.legend = TRUE) + diagnostic_points(size) + vline_at( c(0.1, 0.5, 1), @@ -408,7 +410,7 @@ zero_pad_int <- function(xs) { } diagnostic_points <- function(size = NULL) { - args <- list(shape = 21, na.rm = TRUE) + args <- list(shape = 21, na.rm = TRUE, show.legend = TRUE) do.call("geom_point", c(args, size = size)) } @@ -454,7 +456,6 @@ diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"), } color_labels <- diagnostic_color_labels[[diagnostic]] - list(diagnostic = diagnostic, aesthetic = aesthetic, color_levels = color_levels, diff --git a/tests/testthat/_snaps/mcmc-diagnostics/mcmc-neff-missing-levels.svg b/tests/testthat/_snaps/mcmc-diagnostics/mcmc-neff-missing-levels.svg new file mode 100644 index 00000000..7201c5bb --- /dev/null +++ b/tests/testthat/_snaps/mcmc-diagnostics/mcmc-neff-missing-levels.svg @@ -0,0 +1,93 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +0.1 +0.25 +0.5 +0.75 +1 +N +e +f +f + +N + + + + + + +N +e +f +f + +N + +0.1 +N +e +f +f + +N + +0.5 +N +e +f +f + +N +> +0.5 +mcmc_neff (missing levels) + + diff --git a/tests/testthat/test-mcmc-diagnostics.R b/tests/testthat/test-mcmc-diagnostics.R index b1d3244e..3ebe0c37 100644 --- a/tests/testthat/test-mcmc-diagnostics.R +++ b/tests/testthat/test-mcmc-diagnostics.R @@ -146,6 +146,17 @@ test_that("mcmc_neff renders correctly", { vdiffr::expect_doppelganger("mcmc_neff (default)", p_base) }) +test_that("mcmc_neff renders legend correctly even if some levels missing", { + testthat::skip_on_cran() + testthat::skip_if_not_installed("vdiffr") + skip_on_r_oldrel() + + neffs <- c(0.1, 0.2, 0.3, 0.4) # above 0.5 is missing but should still appear in legend + + p_base <- mcmc_neff(neffs) + vdiffr::expect_doppelganger("mcmc_neff (missing levels)", p_base) +}) + test_that("mcmc_neff_hist renders correctly", { testthat::skip_on_cran() testthat::skip_if_not_installed("vdiffr")