From 982a713e0d6ccbb8cd78d1fa337523655f1fda21 Mon Sep 17 00:00:00 2001 From: TJ Mahr Date: Mon, 1 Feb 2021 12:25:03 -0600 Subject: [PATCH 1/2] add density options to mcmc_dens_overlay() closes #258 --- NEWS.md | 2 + R/mcmc-distributions.R | 251 ++++++++++++++++++++++++-------------- man/MCMC-distributions.Rd | 16 ++- 3 files changed, 174 insertions(+), 95 deletions(-) diff --git a/NEWS.md b/NEWS.md index bb090c96..673546c8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ Items for next release go here +* `mcmc_dens()` and `mcmc_dens_overlay()` gain arguments for controlling the + the density calculation. (#258) # bayesplot 1.8.0 diff --git a/R/mcmc-distributions.R b/R/mcmc-distributions.R index 63a63c95..7f34e11c 100644 --- a/R/mcmc-distributions.R +++ b/R/mcmc-distributions.R @@ -11,6 +11,7 @@ #' @template args-regex_pars #' @template args-transformations #' @template args-facet_args +#' @template args-density-controls #' @param ... Currently ignored. #' #' @template return-ggplot @@ -105,15 +106,17 @@ NULL #' @template args-hist #' @template args-hist-freq #' -mcmc_hist <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - facet_args = list(), - binwidth = NULL, - breaks = NULL, - freq = TRUE) { +mcmc_hist <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + facet_args = list(), + binwidth = NULL, + breaks = NULL, + freq = TRUE +) { check_ignored_arguments(...) .mcmc_hist( x, @@ -131,13 +134,19 @@ mcmc_hist <- function(x, #' @rdname MCMC-distributions #' @export -mcmc_dens <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - facet_args = list(), - trim = FALSE) { +mcmc_dens <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + facet_args = list(), + trim = FALSE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL +) { check_ignored_arguments(...) .mcmc_dens( x, @@ -147,6 +156,10 @@ mcmc_dens <- function(x, facet_args = facet_args, by_chain = FALSE, trim = trim, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL, ... ) } @@ -154,14 +167,16 @@ mcmc_dens <- function(x, #' @rdname MCMC-distributions #' @export #' -mcmc_hist_by_chain <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - facet_args = list(), - binwidth = NULL, - freq = TRUE) { +mcmc_hist_by_chain <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + facet_args = list(), + binwidth = NULL, + freq = TRUE +) { check_ignored_arguments(...) .mcmc_hist( x, @@ -178,14 +193,20 @@ mcmc_hist_by_chain <- function(x, #' @rdname MCMC-distributions #' @export -mcmc_dens_overlay <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - facet_args = list(), - color_chains = TRUE, - trim = FALSE) { +mcmc_dens_overlay <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + facet_args = list(), + color_chains = TRUE, + trim = FALSE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL +) { check_ignored_arguments(...) .mcmc_dens( x, @@ -196,6 +217,10 @@ mcmc_dens_overlay <- function(x, by_chain = TRUE, color_chains = color_chains, trim = trim, + bw = bw, + adjust = adjust, + kernel = kernel, + n_dens = n_dens, ... ) } @@ -204,19 +229,29 @@ mcmc_dens_overlay <- function(x, #' @template args-density-controls #' @param color_chains Option for whether to separately color chains. #' @export -mcmc_dens_chains <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - color_chains = TRUE, - bw = NULL, adjust = NULL, kernel = NULL, - n_dens = NULL) { +mcmc_dens_chains <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + color_chains = TRUE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL +) { check_ignored_arguments(...) - data <- mcmc_dens_chains_data(x, pars = pars, regex_pars = regex_pars, - transformations = transformations, bw = bw, - adjust = adjust, kernel = kernel, - n_dens = n_dens) + data <- mcmc_dens_chains_data( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations, + bw = bw, + adjust = adjust, + kernel = kernel, + n_dens = n_dens + ) n_chains <- length(unique(data$chain)) if (n_chains == 1) STOP_need_multiple_chains() @@ -233,17 +268,22 @@ mcmc_dens_chains <- function(x, } ggplot(data) + - aes_(x = ~ x, y = ~ parameter, color = ~ chain, - group = ~ interaction(chain, parameter)) + + aes_( + x = ~ x, y = ~ parameter, color = ~ chain, + group = ~ interaction(chain, parameter) + ) + geom_line(data = line_training) + ggridges::geom_density_ridges( aes_(height = ~ density), stat = "identity", fill = NA, - show.legend = FALSE) + + show.legend = FALSE + ) + labs(color = "Chain") + - scale_y_discrete(limits = unique(rev(data$parameter)), - expand = c(0.05, .6)) + + scale_y_discrete( + limits = unique(rev(data$parameter)), + expand = c(0.05, .6) + ) + scale_color + bayesplot_theme_get() + yaxis_title(FALSE) + @@ -254,38 +294,48 @@ mcmc_dens_chains <- function(x, #' @rdname MCMC-distributions #' @export -mcmc_dens_chains_data <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - bw = NULL, adjust = NULL, kernel = NULL, - n_dens = NULL) { +mcmc_dens_chains_data <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + bw = NULL, adjust = NULL, kernel = NULL, + n_dens = NULL +) { check_ignored_arguments(...) x %>% - prepare_mcmc_array(pars = pars, regex_pars = regex_pars, - transformations = transformations) %>% + prepare_mcmc_array( + pars = pars, + regex_pars = regex_pars, + transformations = transformations + ) %>% melt_mcmc() %>% - compute_column_density(c(.data$Parameter, .data$Chain), .data$Value, - interval_width = 1, - bw = bw, adjust = adjust, kernel = kernel, - n_dens = n_dens) %>% + compute_column_density( + group_vars = c(.data$Parameter, .data$Chain), + value_var = .data$Value, + interval_width = 1, + bw = bw, adjust = adjust, kernel = kernel, n_dens = n_dens + ) %>% mutate(Chain = factor(.data$Chain)) %>% rlang::set_names(tolower) %>% dplyr::as_tibble() } + #' @rdname MCMC-distributions #' @inheritParams ppc_violin_grouped #' @export -mcmc_violin <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - ..., - facet_args = list(), - probs = c(0.1, 0.5, 0.9)) { +mcmc_violin <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + facet_args = list(), + probs = c(0.1, 0.5, 0.9) +) { check_ignored_arguments(...) .mcmc_dens( x, @@ -303,16 +353,18 @@ mcmc_violin <- function(x, # internal ----------------------------------------------------------------- -.mcmc_hist <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - facet_args = list(), - binwidth = NULL, - breaks = NULL, - by_chain = FALSE, - freq = TRUE, - ...) { +.mcmc_hist <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + facet_args = list(), + binwidth = NULL, + breaks = NULL, + by_chain = FALSE, + freq = TRUE, + ... +) { x <- prepare_mcmc_array(x, pars, regex_pars, transformations) if (by_chain && !has_multiple_chains(x)) { @@ -363,19 +415,31 @@ mcmc_violin <- function(x, xaxis_title(on = n_param == 1) } -.mcmc_dens <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - facet_args = list(), - by_chain = FALSE, - color_chains = FALSE, - geom = c("density", "violin"), - probs = c(0.1, 0.5, 0.9), - trim = FALSE, - ...) { +.mcmc_dens <- function( + x, + pars = character(), + regex_pars = character(), + transformations = list(), + facet_args = list(), + by_chain = FALSE, + color_chains = FALSE, + geom = c("density", "violin"), + probs = c(0.1, 0.5, 0.9), + trim = FALSE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL, + ... +) { + + bw <- bw %||% "nrd0" + adjust <- adjust %||% 1 + kernel <- kernel %||% "gaussian" + n_dens <- n_dens %||% 1024 + x <- prepare_mcmc_array(x, pars, regex_pars, transformations) - data <- melt_mcmc(x) + data <- melt_mcmc.mcmc_array(x) data$Chain <- factor(data$Chain) n_param <- num_params(data) @@ -396,11 +460,16 @@ mcmc_violin <- function(x, } else { list(x = ~ Value) } + geom_args <- list(size = 0.5, na.rm = TRUE) if (violin) { geom_args[["draw_quantiles"]] <- probs } else { geom_args[["trim"]] <- trim + geom_args[["bw"]] <- bw + geom_args[["adjust"]] <- adjust + geom_args[["kernel"]] <- kernel + geom_args[["n"]] <- n_dens } if (by_chain) { diff --git a/man/MCMC-distributions.Rd b/man/MCMC-distributions.Rd index 70385b96..43f7eb2f 100644 --- a/man/MCMC-distributions.Rd +++ b/man/MCMC-distributions.Rd @@ -30,7 +30,11 @@ mcmc_dens( transformations = list(), ..., facet_args = list(), - trim = FALSE + trim = FALSE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL ) mcmc_hist_by_chain( @@ -52,7 +56,11 @@ mcmc_dens_overlay( ..., facet_args = list(), color_chains = TRUE, - trim = FALSE + trim = FALSE, + bw = NULL, + adjust = NULL, + kernel = NULL, + n_dens = NULL ) mcmc_dens_chains( @@ -153,12 +161,12 @@ function.)} \item{trim}{A logical scalar passed to \code{\link[ggplot2:geom_density]{ggplot2::geom_density()}}.} -\item{color_chains}{Option for whether to separately color chains.} - \item{bw, adjust, kernel, n_dens}{Optional arguments passed to \code{\link[stats:density]{stats::density()}} to override default kernel density estimation parameters. \code{n_dens} defaults to \code{1024}.} +\item{color_chains}{Option for whether to separately color chains.} + \item{probs}{A numeric vector passed to \code{\link[ggplot2:geom_violin]{ggplot2::geom_violin()}}'s \code{draw_quantiles} argument to specify at which quantiles to draw horizontal lines. Set to \code{NULL} to remove the lines.} From 13c8253bb518523bcfc77ac19843c79b40db1973 Mon Sep 17 00:00:00 2001 From: TJ Mahr Date: Mon, 1 Feb 2021 14:05:37 -0600 Subject: [PATCH 2/2] make sure mcmc_dens() works --- R/mcmc-distributions.R | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/R/mcmc-distributions.R b/R/mcmc-distributions.R index 7f34e11c..abd611aa 100644 --- a/R/mcmc-distributions.R +++ b/R/mcmc-distributions.R @@ -156,10 +156,10 @@ mcmc_dens <- function( facet_args = facet_args, by_chain = FALSE, trim = trim, - bw = NULL, - adjust = NULL, - kernel = NULL, - n_dens = NULL, + bw = bw, + adjust = adjust, + kernel = kernel, + n_dens = n_dens, ... ) } @@ -445,7 +445,7 @@ mcmc_violin <- function( geom <- match.arg(geom) violin <- geom == "violin" - geom_fun <- if (by_chain) "stat_density" else paste0("geom_", geom) + geom_fun <- if (!violin) "stat_density" else "geom_violin" if (by_chain || violin) { if (!has_multiple_chains(x)) { @@ -519,3 +519,4 @@ mcmc_violin <- function( yaxis_title(on = n_param == 1 && violin) + xaxis_title(on = n_param == 1) } +