Skip to content

Commit

Permalink
Add split-chain option to rank overlay plots
Browse files Browse the repository at this point in the history
Related to #333
  • Loading branch information
sims1253 committed Dec 16, 2024
1 parent 20910f5 commit 3ac550c
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 4 deletions.
36 changes: 32 additions & 4 deletions R/mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
#' of rank-normalized MCMC samples. Defaults to `20`.
#' @param ref_line For the rank plots, whether to draw a horizontal line at the
#' average number of ranks per bin. Defaults to `FALSE`.
#' @param split_chains Logical indicating whether to split each chain into two parts.
#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
#' Defaults to `FALSE`.
#' @export
mcmc_rank_overlay <- function(x,
pars = character(),
Expand All @@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
facet_args = list(),
...,
n_bins = 20,
ref_line = FALSE) {
ref_line = FALSE,
split_chains = FALSE) {
check_ignored_arguments(...)
data <- mcmc_trace_data(
x,
Expand All @@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x,
transformations = transformations
)

n_chains <- unique(data$n_chains)
# Split chains if requested
if (split_chains) {
data$n_chains = data$n_chains/2
data$n_iterations = data$n_iterations/2
# Calculate midpoint for each chain
n_samples <- length(unique(data$iteration))
midpoint <- n_samples/2

# Create new data frame with split chains
data <- data %>%
group_by(.data$chain) %>%
mutate(
chain = ifelse(
iteration <= midpoint,
paste0(.data$chain, "_1"),
paste0(.data$chain, "_2")
)
) %>%
ungroup()
}

n_chains <- length(unique(data$chain))
n_param <- unique(data$n_parameters)

# We have to bin and count the data ourselves because
Expand All @@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x,
bin_start = unique(histobins$bin_start),
stringsAsFactors = FALSE
))

d_bin_counts <- all_combos %>%
left_join(d_bin_counts, by = c("parameter", "chain", "bin_start")) %>%
mutate(n = dplyr::if_else(is.na(n), 0L, n))
Expand All @@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x,
mutate(bin_start = right_edge) %>%
dplyr::bind_rows(d_bin_counts)

scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains))
# Update legend title based on split_chains
legend_title <- if (split_chains) "Split Chains" else "Chain"
scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chains))

layer_ref_line <- if (ref_line) {
geom_hline(
Expand All @@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x,
}

ggplot(d_bin_counts) +
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
geom_step() +
layer_ref_line +
facet_call +
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions tests/testthat/data-for-mcmc-tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,11 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df(
)
)

vdiff_dframe_rank_overlay_split_chain_test <- posterior::as_draws_df(
list(
list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)),
list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5))
)
)

set.seed(seed = NULL)
7 changes: 7 additions & 0 deletions tests/testthat/test-mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ test_that("mcmc_rank_overlay renders correctly", {
# https://github.com/stan-dev/bayesplot/issues/331
p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test)

# https://github.com/stan-dev/bayesplot/issues/333
p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_split_chain_test,
split_chains = TRUE)

vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base)
vdiffr::expect_doppelganger(
"mcmc_rank_overlay (reference line)",
Expand All @@ -170,6 +174,9 @@ test_that("mcmc_rank_overlay renders correctly", {

# https://github.com/stan-dev/bayesplot/issues/331
vdiffr::expect_doppelganger("mcmc_rank_overlay (not all bins)", p_not_all_bins_exist)

# https://github.com/stan-dev/bayesplot/issues/333
vdiffr::expect_doppelganger("mcmc_rank_overlay (split chains)", p_split_chains)
})

test_that("mcmc_rank_hist renders correctly", {
Expand Down

0 comments on commit 3ac550c

Please sign in to comment.