Skip to content

Commit

Permalink
Merge pull request #261 from stan-dev/issue-258
Browse files Browse the repository at this point in the history
add density controls to mcmc_dens() and mcmc_dens_overlay()
  • Loading branch information
jgabry authored Oct 5, 2021
2 parents 60edc5a + bb50861 commit 3fc3a89
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 96 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* Fix R cmd check error on linux for CRAN

* `mcmc_dens()` and `mcmc_dens_overlay()` gain arguments for controlling the
the density calculation. (#258)

# bayesplot 1.8.0

Expand Down
254 changes: 162 additions & 92 deletions R/mcmc-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @template args-regex_pars
#' @template args-transformations
#' @template args-facet_args
#' @template args-density-controls
#' @param ... Currently ignored.
#'
#' @template return-ggplot
Expand Down Expand Up @@ -105,15 +106,17 @@ NULL
#' @template args-hist
#' @template args-hist-freq
#'
mcmc_hist <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
binwidth = NULL,
breaks = NULL,
freq = TRUE) {
mcmc_hist <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
binwidth = NULL,
breaks = NULL,
freq = TRUE
) {
check_ignored_arguments(...)
.mcmc_hist(
x,
Expand All @@ -131,13 +134,19 @@ mcmc_hist <- function(x,

#' @rdname MCMC-distributions
#' @export
mcmc_dens <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
trim = FALSE) {
mcmc_dens <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
trim = FALSE,
bw = NULL,
adjust = NULL,
kernel = NULL,
n_dens = NULL
) {
check_ignored_arguments(...)
.mcmc_dens(
x,
Expand All @@ -147,21 +156,27 @@ mcmc_dens <- function(x,
facet_args = facet_args,
by_chain = FALSE,
trim = trim,
bw = bw,
adjust = adjust,
kernel = kernel,
n_dens = n_dens,
...
)
}

#' @rdname MCMC-distributions
#' @export
#'
mcmc_hist_by_chain <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
binwidth = NULL,
freq = TRUE) {
mcmc_hist_by_chain <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
binwidth = NULL,
freq = TRUE
) {
check_ignored_arguments(...)
.mcmc_hist(
x,
Expand All @@ -178,14 +193,20 @@ mcmc_hist_by_chain <- function(x,

#' @rdname MCMC-distributions
#' @export
mcmc_dens_overlay <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
color_chains = TRUE,
trim = FALSE) {
mcmc_dens_overlay <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
color_chains = TRUE,
trim = FALSE,
bw = NULL,
adjust = NULL,
kernel = NULL,
n_dens = NULL
) {
check_ignored_arguments(...)
.mcmc_dens(
x,
Expand All @@ -196,6 +217,10 @@ mcmc_dens_overlay <- function(x,
by_chain = TRUE,
color_chains = color_chains,
trim = trim,
bw = bw,
adjust = adjust,
kernel = kernel,
n_dens = n_dens,
...
)
}
Expand All @@ -204,19 +229,29 @@ mcmc_dens_overlay <- function(x,
#' @template args-density-controls
#' @param color_chains Option for whether to separately color chains.
#' @export
mcmc_dens_chains <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
color_chains = TRUE,
bw = NULL, adjust = NULL, kernel = NULL,
n_dens = NULL) {
mcmc_dens_chains <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
color_chains = TRUE,
bw = NULL,
adjust = NULL,
kernel = NULL,
n_dens = NULL
) {
check_ignored_arguments(...)
data <- mcmc_dens_chains_data(x, pars = pars, regex_pars = regex_pars,
transformations = transformations, bw = bw,
adjust = adjust, kernel = kernel,
n_dens = n_dens)
data <- mcmc_dens_chains_data(
x,
pars = pars,
regex_pars = regex_pars,
transformations = transformations,
bw = bw,
adjust = adjust,
kernel = kernel,
n_dens = n_dens
)

n_chains <- length(unique(data$chain))
if (n_chains == 1) STOP_need_multiple_chains()
Expand All @@ -233,17 +268,22 @@ mcmc_dens_chains <- function(x,
}

ggplot(data) +
aes_(x = ~ x, y = ~ parameter, color = ~ chain,
group = ~ interaction(chain, parameter)) +
aes_(
x = ~ x, y = ~ parameter, color = ~ chain,
group = ~ interaction(chain, parameter)
) +
geom_line(data = line_training) +
ggridges::geom_density_ridges(
aes_(height = ~ density),
stat = "identity",
fill = NA,
show.legend = FALSE) +
show.legend = FALSE
) +
labs(color = "Chain") +
scale_y_discrete(limits = unique(rev(data$parameter)),
expand = c(0.05, .6)) +
scale_y_discrete(
limits = unique(rev(data$parameter)),
expand = c(0.05, .6)
) +
scale_color +
bayesplot_theme_get() +
yaxis_title(FALSE) +
Expand All @@ -254,38 +294,48 @@ mcmc_dens_chains <- function(x,

#' @rdname MCMC-distributions
#' @export
mcmc_dens_chains_data <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
bw = NULL, adjust = NULL, kernel = NULL,
n_dens = NULL) {
mcmc_dens_chains_data <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
bw = NULL, adjust = NULL, kernel = NULL,
n_dens = NULL
) {
check_ignored_arguments(...)

x %>%
prepare_mcmc_array(pars = pars, regex_pars = regex_pars,
transformations = transformations) %>%
prepare_mcmc_array(
pars = pars,
regex_pars = regex_pars,
transformations = transformations
) %>%
melt_mcmc() %>%
compute_column_density(c(.data$Parameter, .data$Chain), .data$Value,
interval_width = 1,
bw = bw, adjust = adjust, kernel = kernel,
n_dens = n_dens) %>%
compute_column_density(
group_vars = c(.data$Parameter, .data$Chain),
value_var = .data$Value,
interval_width = 1,
bw = bw, adjust = adjust, kernel = kernel, n_dens = n_dens
) %>%
mutate(Chain = factor(.data$Chain)) %>%
rlang::set_names(tolower) %>%
dplyr::as_tibble()
}


#' @rdname MCMC-distributions
#' @inheritParams ppc_violin_grouped
#' @export
mcmc_violin <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
probs = c(0.1, 0.5, 0.9)) {
mcmc_violin <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
facet_args = list(),
probs = c(0.1, 0.5, 0.9)
) {
check_ignored_arguments(...)
.mcmc_dens(
x,
Expand All @@ -303,16 +353,18 @@ mcmc_violin <- function(x,


# internal -----------------------------------------------------------------
.mcmc_hist <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
facet_args = list(),
binwidth = NULL,
breaks = NULL,
by_chain = FALSE,
freq = TRUE,
...) {
.mcmc_hist <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
facet_args = list(),
binwidth = NULL,
breaks = NULL,
by_chain = FALSE,
freq = TRUE,
...
) {
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)

if (by_chain && !has_multiple_chains(x)) {
Expand Down Expand Up @@ -363,25 +415,37 @@ mcmc_violin <- function(x,
xaxis_title(on = n_param == 1)
}

.mcmc_dens <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
facet_args = list(),
by_chain = FALSE,
color_chains = FALSE,
geom = c("density", "violin"),
probs = c(0.1, 0.5, 0.9),
trim = FALSE,
...) {
.mcmc_dens <- function(
x,
pars = character(),
regex_pars = character(),
transformations = list(),
facet_args = list(),
by_chain = FALSE,
color_chains = FALSE,
geom = c("density", "violin"),
probs = c(0.1, 0.5, 0.9),
trim = FALSE,
bw = NULL,
adjust = NULL,
kernel = NULL,
n_dens = NULL,
...
) {

bw <- bw %||% "nrd0"
adjust <- adjust %||% 1
kernel <- kernel %||% "gaussian"
n_dens <- n_dens %||% 1024

x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
data <- melt_mcmc(x)
data <- melt_mcmc.mcmc_array(x)
data$Chain <- factor(data$Chain)
n_param <- num_params(data)

geom <- match.arg(geom)
violin <- geom == "violin"
geom_fun <- if (by_chain) "stat_density" else paste0("geom_", geom)
geom_fun <- if (!violin) "stat_density" else "geom_violin"

if (by_chain || violin) {
if (!has_multiple_chains(x)) {
Expand All @@ -396,11 +460,16 @@ mcmc_violin <- function(x,
} else {
list(x = ~ Value)
}

geom_args <- list(size = 0.5, na.rm = TRUE)
if (violin) {
geom_args[["draw_quantiles"]] <- probs
} else {
geom_args[["trim"]] <- trim
geom_args[["bw"]] <- bw
geom_args[["adjust"]] <- adjust
geom_args[["kernel"]] <- kernel
geom_args[["n"]] <- n_dens
}

if (by_chain) {
Expand Down Expand Up @@ -450,3 +519,4 @@ mcmc_violin <- function(x,
yaxis_title(on = n_param == 1 && violin) +
xaxis_title(on = n_param == 1)
}

Loading

0 comments on commit 3fc3a89

Please sign in to comment.