diff --git a/NAMESPACE b/NAMESPACE index 62cf616b..f8497c96 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -171,6 +171,9 @@ S3method(iteration_ids,draws_rvars) S3method(iteration_ids,rvar) S3method(length,rvar) S3method(levels,rvar) +S3method(log_weights,draws) +S3method(log_weights,draws_rvars) +S3method(log_weights,rvar) S3method(mad,default) S3method(mad,rvar) S3method(mad,rvar_ordered) @@ -394,7 +397,9 @@ S3method(weight_draws,draws_df) S3method(weight_draws,draws_list) S3method(weight_draws,draws_matrix) S3method(weight_draws,draws_rvars) +S3method(weight_draws,rvar) S3method(weights,draws) +S3method(weights,rvar) export("%**%") export("%in%") export("draws_of<-") @@ -454,6 +459,7 @@ export(is_rvar) export(is_rvar_factor) export(is_rvar_ordered) export(iteration_ids) +export(log_weights) export(mad) export(match) export(mcse_mean) diff --git a/NEWS.md b/NEWS.md index 03f5a1a8..1d0e079d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * Add `pareto_smooth` option to `weight_draws`, to Pareto smooth weights before adding to a draws object. +* Add support for applying weights to individual `rvar` objects. +* Add `log_weights()` function for easy access to raw internal weights. * Matrix multiplication of `rvar`s can now be done with the base matrix multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. * `variables()`, `variables<-()`, `set_variables()`, and `nvariables()` now diff --git a/R/as_draws_array.R b/R/as_draws_array.R index 8ed9a529..84f4ec3a 100644 --- a/R/as_draws_array.R +++ b/R/as_draws_array.R @@ -96,6 +96,7 @@ as_draws_array.draws_rvars <- function(x, ...) 
{ x <- check_variables_are_numeric( x, to = "draws_array", is_non_numeric = is_rvar_factor, convert = FALSE ) + x <- promote_rvar_weights_to_variable(x) # cbind discards class information when applied to vectors, which converts # the underlying factors to numeric diff --git a/R/as_draws_df.R b/R/as_draws_df.R index 9eefcdba..28d17a60 100644 --- a/R/as_draws_df.R +++ b/R/as_draws_df.R @@ -110,6 +110,7 @@ as_draws_df.draws_rvars <- function(x, ...) { if (ndraws(x) == 0L) { return(empty_draws_df(variables(x))) } + x <- promote_rvar_weights_to_variable(x) out <- do.call(cbind, lapply(seq_along(x), function(i) { # flatten each rvar so it only has two dimensions: draws and variables # this also collapses indices into variable names in the format "var[i,j,k,...]" diff --git a/R/as_draws_matrix.R b/R/as_draws_matrix.R index 5d7a37c6..03fe4ab3 100644 --- a/R/as_draws_matrix.R +++ b/R/as_draws_matrix.R @@ -85,6 +85,7 @@ as_draws_matrix.draws_rvars <- function(x, ...) { x <- check_variables_are_numeric( x, to = "draws_matrix", is_non_numeric = is_rvar_factor, convert = FALSE ) + x <- promote_rvar_weights_to_variable(x) # cbind discards class information when applied to vectors, which converts # the underlying factors to numeric diff --git a/R/as_draws_rvars.R b/R/as_draws_rvars.R index fd7a558a..b60a051e 100755 --- a/R/as_draws_rvars.R +++ b/R/as_draws_rvars.R @@ -207,9 +207,27 @@ as_draws_rvars.mcmc.list <- function(x, ...) 
{ check_new_variables(names(x)) - x <- conform_rvar_ndraws_nchains(x) + x <- conform_rvar_nchains_ndraws_weights(x) class(x) <- class_draws_rvars() + + # move the .log_weight column into the log_weights attribute of each rvar, + # but only if there is no conflict between any existing weights on the rvars + if (".log_weight" %in% names(x)) { + existing_weights <- log_weights(x[[1]]) + .log_weight <- as.vector(draws_of(x$.log_weight)) + if (is.null(existing_weights)) { + x$.log_weight <- NULL + x <- weight_draws(x, .log_weight, log = TRUE) + } else { + # if we reach this point either existing_weights and .log_weight + # are identical (so we don't have to do anything) or they aren't + # and weights2_common will throw the appropriate error --- thus + # we don't need to do anything with its output + weights2_common(existing_weights, .log_weight) + } + } + x } @@ -258,3 +276,13 @@ empty_draws_rvars <- function(variables = character(0), nchains = 0) { class(out) <- class_draws_rvars() out } + +# when converting draws_rvars to other formats, we must promote log weights +# to be a variable before doing the conversion +promote_rvar_weights_to_variable <- function(x) { + .log_weights <- log_weights(x) + if (!is.null(.log_weights)) { + x$.log_weight <- rvar(log_weights(x), nchains = nchains(x)) + } + x +} diff --git a/R/convergence.R b/R/convergence.R index cf895f07..0fb6e95f 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -82,6 +82,8 @@ rhat_basic.rvar <- function(x, split = TRUE, ...) { #' recommend the improved ESS convergence diagnostics implemented in #' [ess_bulk()] and [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -104,18 +106,27 @@ ess_basic <- function(x, ...) UseMethod("ess_basic") #' @rdname ess_basic #' @export -ess_basic.default <- function(x, split = TRUE, ...) 
{ +ess_basic.default <- function(x, split = TRUE, weights = NULL, ...) { split <- as_one_logical(split) if (split) { x <- .split_chains(x) } - .ess(x) + + if (is.null(weights)) { + .ess(x) + } else { + r_eff <- .ess(x) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_basic #' @export ess_basic.rvar <- function(x, split = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_basic, split, ...) + + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_basic, split, weights = weights, ...) + } #' Rhat convergence diagnostic @@ -162,6 +173,8 @@ rhat.rvar <- function(x, ...) { #' rank normalized values using split chains. For the tail effective sample size #' see [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -182,14 +195,27 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { - .ess(z_scale(.split_chains(x))) +ess_bulk.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -200,6 +226,8 @@ ess_bulk.rvar <- function(x, ...) { #' sample sizes for 5% and 95% quantiles. For the bulk effective sample #' size see [ess_bulk()]. 
See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -220,22 +248,24 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles #' -#' Compute effective sample size estimates for quantile estimates of a single -#' variable. +#' Compute effective sample size estimates for quantile estimates of a +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -258,13 +288,26 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) 
{ probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) + out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) + + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +317,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -293,10 +337,23 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- (x <= quantile(x, prob)) .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- (x <= weighted_quantile(x, prob, weights = weights)) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -319,14 +376,28 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) 
{ + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -334,6 +405,8 @@ ess_mean.rvar <- function(x, ...) { #' Compute an effective sample size estimate for the standard deviation (SD) #' estimate of a single variable. This is defined as the effective sample size #' estimate for the absolute deviation from mean. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -353,20 +426,36 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x - mean(x)))) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } +# TODO: ess_weights + #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a -#' single variable. +#' single variable. 
If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -389,23 +478,36 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -415,6 +517,7 @@ mcse_median <- function(x, ...) { } # MCSE of a single quantile +# TODO: refer to paper .mcse_quantile <- function(x, prob) { ess <- ess_quantile(x, prob) p <- c(0.1586553, 0.8413447) @@ -423,13 +526,32 @@ mcse_median <- function(x, ...) 
{ S <- length(ssims) th1 <- ssims[max(floor(a[1] * S), 1)] th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) } +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + x_idx <- order(x) + x_sorted <- x[x_idx] + weights_sorted <- weights[x_idx] + S <- length(x) + + cweights <- cumsum(weights_sorted) + th1 <- x_sorted[max(max(which(cweights < a[1])), 1)] + th2 <- x_sorted[min(min(which(cweights > a[2])), S)] + + as.vector((th2 - th1) / 2) +} + + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -449,14 +571,27 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, r_eff, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -484,28 +619,60 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") #' @rdname mcse_sd #' @export -mcse_sd.default <- function(x, ...) { - # var/sd are not a simple expectation of g(X), e.g. variance - # has (X-E[X])^2. 
The following ESS is based on a relevant quantity - # in the computation and is empirically a good choice. - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - # Variance of variance estimate by Kenney and Keeping (1951, p. 141), - # which doesn't assume normality of sims. - Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess - # The first order Taylor series approximation of variance of sd. - # Kenney and Keeping (1951, p. 141) write "...since fluctuations of - # any moment are of order N^{-1/2}, squares and higher powers of - # differentials of the moments can be neglected " - varsd <- varvar / Evar / 4 - sqrt(varsd) +mcse_sd.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + + # var/sd are not a simple expectation of g(X), e.g. variance + # has (X-E[X])^2. The following ESS is based on a relevant quantity + # in the computation and is empirically a good choice. + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + # Variance of variance estimate by Kenney and Keeping (1951, p. 141), + # which doesn't assume normality of sims. + Evar <- mean(sims_c^2) + varvar <- (mean(sims_c^4) - Evar^2) / ess # (Equation 6.20) + + # The first order Taylor series approximation of variance of sd. + # Kenney and Keeping (1951, p. 
141) write "...since fluctuations of + # any moment are of order N^{-1/2}, squares and higher powers of + # differentials of the moments can be neglected " + varsd <- varvar / Evar / 4 + sqrt(varsd) + + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights + + first_moment_weighted <- weighted.mean(x, w = weights) + + x_centered <- x - first_moment_weighted + second_moment_weighted <- weighted.mean(x_centered^2, w = weights) + fourth_moment_weighted <- weighted.mean(x_centered^4, w = weights) + + r_eff <- .ess(x_centered^2) / (nrow(x) * ncol(x)) + weighted_ess <- .ess_weighted(x_centered^2, weights = weights, r_eff = r_eff) + + # Kenney and Keeping (1951, eq 6.20) + varvar_weighted <- (fourth_moment_weighted - second_moment_weighted^2) / weighted_ess + + # First-order Taylor series approximation + varsd <- varvar_weighted / second_moment_weighted / 4 + sqrt(varsd) + } } #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -541,12 +708,9 @@ quantile2.default <- function( ) { names <- as_one_logical(names) na.rm <- as_one_logical(na.rm) - if (!na.rm && anyNA(x)) { - # quantile itself doesn't handle this case (#110) - out <- rep(NA_real_, length(probs)) - } else { - out <- quantile(x, probs = probs, na.rm = na.rm, ...) - } + + out <- weighted_quantile(x, probs = probs, na.rm = na.rm, ...) + if (names) { names(out) <- paste0("q", probs * 100) } else { @@ -560,7 +724,12 @@ quantile2.default <- function( quantile2.rvar <- function( x, probs = c(0.05, 0.95), na.rm = FALSE, names = TRUE, ... ) { - summarise_rvar_by_element_with_chains(x, quantile2, probs, na.rm, names, ...) 
+ weights <- weights(x) + summarise_rvar_by_element(x, function(draws) { + quantile2( + draws, probs = probs, weights = weights, na.rm = na.rm, names = names, ... + ) + }) } # internal ---------------------------------------------------------------- @@ -782,6 +951,23 @@ fold_draws <- function(x) { ess } +.mcse_weighted <- function(x, weights, r_eff, ...) { + # Vehtari et al. 2022 equation 6 + + x <- as.numeric(x) + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) +} + +.ess_weighted <- function(x, weights, r_eff, ...) { + # Vehtari et al. 2022 equation 7 + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + var <- mean((x - mean(x))^2) + var / mcse^2 +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { diff --git a/R/discrete-summaries.R b/R/discrete-summaries.R index 3c141c9d..f4f4e140 100644 --- a/R/discrete-summaries.R +++ b/R/discrete-summaries.R @@ -2,11 +2,9 @@ #' #' Normalized entropy, for measuring dispersion in draws from categorical distributions. #' -#' @param x (multiple options) A vector to be interpreted as draws from -#' a categorical distribution, such as: -#' - A [factor] -#' - A [numeric] (should be [integer] or integer-like) -#' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-x-categorical +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Calculates the normalized Shannon entropy of the draws in `x`. This value is @@ -51,14 +49,14 @@ #' xy #' entropy(xy) #' @export -entropy <- function(x) { +entropy <- function(x, ...) { UseMethod("entropy") } #' @rdname entropy #' @export -entropy.default <- function(x) { +entropy.default <- function(x, weights = NULL, ...) 
{ if (anyNA(x)) return(NA_real_) - p <- prop.table(simple_table(x)$count) + p <- prop.table(weighted_simple_table(x, weights)$count) n <- length(p) if (n == 1) { @@ -71,8 +69,8 @@ entropy.default <- function(x) { } #' @rdname entropy #' @export -entropy.rvar <- function(x) { - summarise_rvar_by_element(x, entropy) +entropy.rvar <- function(x, ...) { + summarise_rvar_by_element(x, entropy, weights = weights(x)) } @@ -85,6 +83,8 @@ entropy.rvar <- function(x) { #' - A [factor] #' - A [numeric] (should be [integer] or integer-like) #' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Calculates Tastle and Wierman's (2007) *dissention* measure: @@ -125,12 +125,12 @@ entropy.rvar <- function(x) { #' xy #' dissent(xy) #' @export -dissent <- function(x) { +dissent <- function(x, ...) { UseMethod("dissent") } #' @rdname dissent #' @export -dissent.default <- function(x) { +dissent.default <- function(x, weights = NULL, ...) { if (anyNA(x)) return(NA_real_) if (length(x) == 0) return(0) @@ -141,21 +141,22 @@ dissent.default <- function(x) { d <- diff(range(x)) } - tab <- simple_table(x) + tab <- weighted_simple_table(x, weights) p <- prop.table(tab$count) if (length(p) == 1) { out <- 0 } else { x_i <- tab$x - out <- -sum(p * log2(1 - abs(x_i - mean(x)) / d)) + mean_x <- if (is.null(weights)) mean(x) else weighted.mean(x, weights) + out <- -sum(p * log2(1 - abs(x_i - mean_x) / d)) } out } #' @rdname dissent #' @export -dissent.rvar <- function(x) { - summarise_rvar_by_element(x, dissent) +dissent.rvar <- function(x, ...) { + summarise_rvar_by_element(x, dissent, weights = weights(x)) } @@ -163,11 +164,9 @@ dissent.rvar <- function(x) { #' #' Modal category of a vector. 
#' -#' @param x (multiple options) A vector to be interpreted as draws from -#' a categorical distribution, such as: -#' - A [factor] -#' - A [numeric] (should be [integer] or integer-like) -#' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-x-categorical +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Finds the modal category (i.e., most frequent value) in `x`. In the case of @@ -192,20 +191,20 @@ dissent.rvar <- function(x) { #' xy #' modal_category(xy) #' @export -modal_category <- function(x) { +modal_category <- function(x, ...) { UseMethod("modal_category") } #' @rdname modal_category #' @export -modal_category.default <- function(x) { +modal_category.default <- function(x, weights = NULL, ...) { if (anyNA(x)) return(NA) - tab <- simple_table(x) + tab <- weighted_simple_table(x, weights) tab$x[which.max(tab$count)] } #' @rdname modal_category #' @export -modal_category.rvar <- function(x) { - summarise_rvar_by_element(x, modal_category) +modal_category.rvar <- function(x, ...) 
{ + summarise_rvar_by_element(x, modal_category, weights = weights(x)) } @@ -231,3 +230,25 @@ simple_table <- function(x) { count = tabulate(x_int, nbins = length(values)) ) } + +#' A weighted version of simple_table +#' @param x a vector (numeric, factor, character, etc) +#' @param weights weights +#' @returns a list with two components of the same length +#' - `x`: unique values from the input `x` +#' - `count`: sum of weights for each unique value of `x` +#' @noRd +weighted_simple_table <- function(x, weights) { + if (is.null(weights)) return(simple_table(x)) + stopifnot(identical(length(x), length(weights))) + + if (is.factor(x)) { + values <- levels(x) + } else { + values <- unique(x) + } + list( + x = values, + count = vapply(split(weights, factor(x, values)), sum, numeric(1), USE.NAMES = FALSE) + ) +} diff --git a/R/draws-index.R b/R/draws-index.R index beb83fa3..fec535e6 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -266,8 +266,11 @@ nchains.rvar <- function(x) { # attribute on an rvar, ALWAYS use this function so that the proxy # cache is invalidated `nchains_rvar<-` <- function(x, value) { - attr(x, "nchains") <- value - invalidate_rvar_cache(x) + if (attr(x, "nchains") != value) { + attr(x, "nchains") <- value + x <- invalidate_rvar_cache(x) + } + x } diff --git a/R/mutate_variables.R b/R/mutate_variables.R index 4ead827e..fd42dd14 100644 --- a/R/mutate_variables.R +++ b/R/mutate_variables.R @@ -86,7 +86,7 @@ mutate_variables.draws_rvars <- function(.x, ...) { for (var in names(dots)) { .x[[var]] <- as_rvar(eval_tidy(dots[[var]], .x, env)) } - conform_rvar_ndraws_nchains(.x) + conform_rvar_nchains_ndraws_weights(.x) } # evaluate an expression passed to 'mutate_variables' and check its validity diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index a0bb8486..794c7188 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,55 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) 
{ - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + smooth_draws = FALSE, + return_k = TRUE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + xu <- weight_draws(x, NULL) + xu <- xu * rvar(w) + + product_diags <- summarise_rvar_by_element_with_chains( + xu, + pareto_khat.default, + verbose = verbose, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -149,6 +183,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +203,35 @@ pareto_diags.rvar <- function(x, ...) 
{ khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -250,7 +315,7 @@ pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = 1, + r_eff = NULL, ndraws_tail = NULL, return_k = FALSE, extra_diags = FALSE, @@ -279,7 +344,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +395,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -444,7 +509,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -458,10 +523,10 @@ pareto_convergence_rate.rvar <- function(x, ...) 
{ draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -617,7 +682,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -630,7 +695,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) diff --git a/R/resample_draws.R b/R/resample_draws.R index f14bbc86..5784b2d1 100644 --- a/R/resample_draws.R +++ b/R/resample_draws.R @@ -72,7 +72,7 @@ resample_draws.draws <- function(x, weights = NULL, method = "stratified", weights <- rep.int(1/ndraws_total, ndraws_total) } # resampling invalidates stored weights - x <- remove_variables(x, ".log_weight") + x <- weight_draws(x, NULL) } else { weights <- weights / sum(weights) } diff --git a/R/rvar-.R b/R/rvar-.R index 33db28fd..32ecdd05 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -26,6 +26,10 @@ #' is ignored and the second dimension of `x` is used to index chains. #' Internally, the array will be converted to a format without the chain index. #' Ignored when `x` is already an [`rvar`]. +#' @param log_weights (numeric vector) A vector of log weights of length `ndraws(x)`. 
+#' Weights will be internally stored on the log scale and will not be normalized, +#' but normalized (non-log) weights can be returned via the [weights.rvar()] +#' method later. #' #' @details #' @@ -54,6 +58,62 @@ #' on the underlying array using the [draws_of()] function. To re-use existing #' random number generator functions to efficiently create `rvar`s, use [rvar_rng()]. #' +#' @section `rvar` Internals: +#' +#' The `rvar` datatype is not intended to be modified directly; rather, you should +#' only use exported functions from \pkg{posterior}, such as [rvar()], [draws_of()], +#' [log_weights()], and [weight_draws()] to create and manipulate `rvar`s. +#' For completeness, and to aid internal development, this section documents the +#' internal structure of the `rvar` datatype. While the public-facing API is +#' intended to be stable, **this internal structure is subject to change without +#' notice**. +#' +#' An `rvar` `x` consists of: +#' +#' - A zero-length `list()` with class `c("rvar", "vctrs_vctr")`. If `draws_of(x)` +#' is a [`factor`], the class will be `c("rvar_factor", "rvar", "vctrs_vctr")`, +#' and if `draws_of(x)` is an [`ordered`], the class will be +#' `c("rvar_ordered", "rvar_factor", "rvar", "vctrs_vctr")`. These classes are +#' set automatically if the underlying draws are modified. +#' +#' The list has these attributes: +#' +#' - `draws`: An [`array`] containing the draws, where the first dimension +#' indexes draws. **Always** get this attribute using [draws_of()] and set it +#' using `draws_of(x) <- value`. To simplify programming, `length(dim(draws_of(x)))` +#' is guaranteed to always be greater than or equal to 2. Zero-length `rvar`s +#' have `dim(draws_of(x)) = c(1,0)`. The draws may be a [`numeric`], +#' [`integer`], [`logical`], [`factor`], or [`ordered`] array. +#' +#' The dimensions after the first are reported as the dimensions of `x`; i.e. +#' `dim(x) = dim(draws_of(x))[-1]` and `dimnames(x) = dimnames(draws_of(x))[-1]`. 
+#' Because `rvar`s *always* have dimensions (unlike base R datatypes, where +#' there is a distinction between a length-*n* vector with no dimensions and +#' a length-*n* array with only 1 dimension), `names(x) = dimnames(x)[[1]]`; +#' i.e., `names()` refers to the names along the first dimension only. +#' +#' - `nchains`: A scalar [`numeric`] giving the number of chains in this `rvar`. +#' **Always** get this attribute using [nchains()]. It cannot be set using the +#' public (exported) API, but can be modified through other functions (e.g. +#' [merge_chains()] or by creating a new [rvar()]). In internal code, **always** +#' set it using `nchains_rvar(x) <- value`. +#' +#' - `log_weights`: A vector [`numeric`] with length `ndraws(x)` giving the +#' log weight on each draw of this `rvar`, or `NULL` if the `rvar` is not +#' weighted. **Always** get this attribute using [weights()] or [log_weights()], +#' and set this attribute using [weight_draws()]. In internal code, it may +#' also be modified directly using `log_weights_rvar(x) <- value`. +#' +#' - `cache`: An [`environment`] that may contain cached output of the \pkg{vctrs} +#' proxy functions on `x` to improve performance of code that makes multiple +#' calls to those functions. The cache is updated automatically and invalidated +#' when necessary so long as the `rvar` is only modified using the functions +#' described in this section (or other functions in the publicly-exported +#' `rvar` API). The environment may contain these variables: +#' +#' - `vec_proxy`: cached output of [vctrs::vec_proxy()]. +#' - `vec_proxy_equal`: cached output of [vctrs::vec_proxy_equal()]. +#' #' @seealso [as_rvar()] to convert objects to `rvar`s. See [rdo()], [rfun()], and #' [rvar_rng()] for higher-level interfaces for creating `rvar`s.
#' @@ -90,7 +150,11 @@ #' x #' #' @export -rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE) { +rvar <- function( + x = double(), dim = NULL, dimnames = NULL, + nchains = NULL, with_chains = FALSE, + log_weights = NULL +) { if (is_rvar(x)) { nchains <- nchains %||% nchains(x) with_chains = FALSE @@ -105,7 +169,7 @@ rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with nchains <- nchains %||% 1L } - out <- new_rvar(x, .nchains = nchains) + out <- new_rvar(x, .nchains = nchains, .log_weights = log_weights) if (!is.null(dim)) { dim(out) <- dim @@ -118,7 +182,7 @@ rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with } #' @importFrom vctrs new_vctr -new_rvar <- function(x = double(), .nchains = 1L) { +new_rvar <- function(x = double(), .nchains = 1L, .log_weights = NULL) { if (is.null(x)) { x <- double() } @@ -128,11 +192,13 @@ new_rvar <- function(x = double(), .nchains = 1L) { .ndraws <- dim(x)[[1]] .nchains <- as_one_integer(.nchains) check_nchains_compat_with_ndraws(.nchains, .ndraws) + .log_weights <- validate_weights(.log_weights, .ndraws, log = TRUE, pareto_smooth = FALSE) structure( list(), draws = x, nchains = .nchains, + log_weights = .log_weights, class = get_rvar_class(x), cache = new.env(parent = emptyenv()) ) @@ -252,14 +318,14 @@ rep.rvar <- function(x, times = 1, length.out = NA, each = 1, ...) 
{ dim = dim(draws) dim[[2]] = dim[[2]] * times dim(rep_draws) = dim - out <- new_rvar(rep_draws, .nchains = nchains(x)) + draws_of(x) <- rep_draws } else { # use `length.out` rep_draws = rep_len(draws, length.out * ndraws(x)) dim(rep_draws) = c(ndraws(x), length(rep_draws) / ndraws(x)) - out <- new_rvar(rep_draws, .nchains = nchains(x)) + draws_of(x) <- rep_draws } - out + x } #' @rawNamespace S3method(rep.int,rvar,rep_int_rvar) @@ -422,7 +488,7 @@ rvar_ifelse = function(test, yes, no) { stop_no_call("`rvar_ifelse(test, yes, no)` requires `test` to be a logical rvar, or castable to one.") } c(yes, no) %<-% vec_cast_common(yes, no) - c(test, yes, no) %<-% conform_array_dims(conform_rvar_ndraws(list(test, yes, no))) + c(test, yes, no) %<-% conform_array_dims(conform_rvar_ndraws_weights(list(test, yes, no))) test_draws <- draws_of(test) false_draws <- test_draws %in% FALSE @@ -538,6 +604,27 @@ nchains2_common <- function(nchains_x, nchains_y) { } } +# find common weights for two rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. 
+#' @noRd +weights2_common <- function(weights_x, weights_y, promote_unweighted = TRUE) { + if (promote_unweighted && is.null(weights_x)) { + weights_y + } else if (promote_unweighted && is.null(weights_y)) { + weights_x + } else if (identical(weights_x, weights_y)) { + weights_x + } else { + stop_no_call( + "Random variables have different log weights and cannot be used together:\n", + "<", vctrs::vec_ptype_abbr(weights_x), "> ", toString(weights_x, width = 60), "\n", + "<", vctrs::vec_ptype_abbr(weights_y), "> ", toString(weights_y, width = 60) + ) + } +} + # check that the given number of chains is compatible with the given number of draws check_nchains_compat_with_ndraws <- function(nchains, ndraws) { # except with constants, nchains must divide the number of draws @@ -549,8 +636,12 @@ check_nchains_compat_with_ndraws <- function(nchains, ndraws) { } } -# given two rvars, conform their number of chains -# so they can be used together (or throw an error if they can't be) +#' given a list of rvars, conform their number of chains +#' so they can be used together (or throw an error if they can't be). Constants +#' are treated as having any number of draws +#' @param rvars a list of rvars +#' @returns modified list of rvars all having the same number of chains +#' @noRd conform_rvar_nchains <- function(rvars) { # find the number of chains to use, treating constants as having any number of chains nchains_or_null <- lapply(rvars, function(x) if (ndraws(x) == 1) NULL else nchains(x)) @@ -563,47 +654,79 @@ conform_rvar_nchains <- function(rvars) { rvars } -# given two rvars, conform their number of draws -# so they can be used together (or throw an error if they can't be) -# @param keep_constants keep constants as 1-draw rvars -conform_rvar_ndraws <- function(rvars, keep_constants = FALSE) { - # broadcast to a common number of chains. If keep_constants = TRUE, - # constants will not be broadcast. 
- .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) +#' given a list of rvars, conform their weights +#' so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. +#' @returns modified list of rvars all having the same weights. +#' @noRd +conform_rvar_weights <- function(rvars, promote_unweighted = TRUE) { + # only check rvars that are not constants --- constant rvars can + # always take on the weights of others + not_constant <- vapply(rvars, ndraws, numeric(1)) > 1 + weights_list <- lapply(rvars[not_constant], log_weights) + .log_weights <- Reduce( + function(...) weights2_common(..., promote_unweighted = promote_unweighted), + weights_list + ) + + for (i in seq_along(rvars)) { + log_weights_rvar(rvars[[i]]) <- .log_weights + } + + rvars +} + +#' given a list of rvars, conform their number of draws and their weights +#' so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. +#' @returns modified list of rvars all having the same number of draws and the +#' same weights.
+#' @noRd +conform_rvar_ndraws_weights <- function(rvars, promote_unweighted = TRUE) { + # must conform weights before ndraws so that constants are handled properly + rvars <- conform_rvar_weights(rvars, promote_unweighted = promote_unweighted) + + .ndraws <- Reduce(ndraws2_common, lapply(rvars, ndraws)) + for (i in seq_along(rvars)) { - rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants) + rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws) } rvars } -# given multiple rvars, conform their number of draws and chains -# so they can be used together (or throw an error if they can't be) -# @param keep_constants keep constants as 1-draw rvars -conform_rvar_ndraws_nchains <- function(rvars, keep_constants = FALSE) { +#' given a list of rvars, conform their number of draws, number of chains, and +#' their weights so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. +#' @returns modified list of rvars all having the same number of chains, same +#' number of draws, and the same weights. 
+#' @noRd +conform_rvar_nchains_ndraws_weights <- function(rvars, promote_unweighted = TRUE) { + # must conform nchains before ndraws so that constants are handled properly rvars <- conform_rvar_nchains(rvars) - rvars <- conform_rvar_ndraws(rvars) + rvars <- conform_rvar_ndraws_weights(rvars, promote_unweighted = promote_unweighted) rvars } -# Check that the first rvar can be conformed to the dimensions of the second, -# ignoring 1s -check_rvar_dims_first <- function(x, y) { - x_dim <- dim(x) - x_dim_dropped <- as.integer(x_dim[x_dim != 1]) - y_dim <- dim(y) - y_dim_dropped <- as.integer(y_dim[y_dim != 1]) - - if (length(x_dim_dropped) == 0) { - # x can be treated as scalar, do so - dim(x) <- rep(1, length(dim(y))) - } else if (identical(x_dim_dropped, y_dim_dropped)) { - dim(x) <- dim(y) - } else { - stop_no_call("Cannot assign an rvar with dimension ", paste0(x_dim, collapse = ","), - " to an rvar with dimension ", paste0(y_dim, collapse = ",")) +#' Check that an rvar is a scalar (length 1) +#' @param x rvar to check +#' @returns x with `dim(x) == 1`, or throws an error if `x` is not scalar. 
+#' @noRd +check_rvar_is_scalar <- function(x) { + if (length(x) != 1) { + stop_no_call("Cannot insert an rvar with length != 1 into another rvar using `[[`") } + dim(x) <- 1 x } @@ -727,20 +850,17 @@ broadcast_array <- function(x, dim, broadcast_scalars = TRUE) { } # broadcast the draws dimension of an rvar to the requested size -broadcast_draws <- function(x, .ndraws, keep_constants = FALSE) { +broadcast_draws <- function(x, .ndraws) { ndraws_x = ndraws(x) - if ( - (ndraws_x == 1 && keep_constants) || - (ndraws_x == .ndraws) - ) { - x - } else { + + if (ndraws_x != .ndraws) { draws <- draws_of(x) new_dim <- dim(draws) new_dim[1] <- .ndraws - - new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x)) + draws_of(x) <- broadcast_array(draws, new_dim) } + + x } #' copy the dimension names (and name of the dimension) from dimension src_i @@ -892,7 +1012,8 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em } else { draws <- apply(draws, 1, .f, ...) if (.transpose) draws <- t(draws) - new_rvar(draws, .nchains = nchains(x)) + draws_of(x) <- draws + x } } @@ -923,7 +1044,8 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o .draws <- .f(draws_of(x), ...) 
} - new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws + x } # apply vectorized function to an rvar's draws diff --git a/R/rvar-bind.R b/R/rvar-bind.R index 68b38885..d6db0dd2 100755 --- a/R/rvar-bind.R +++ b/R/rvar-bind.R @@ -89,9 +89,10 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { draws_axis <- axis + 1 # because first dim is draws - # conform nchains + # conform nchains and weights # (don't need to do draws here since that's part of the broadcast below) c(x, y) %<-% conform_rvar_nchains(list(x, y)) + c(x, y) %<-% conform_rvar_weights(list(x, y), promote_unweighted = FALSE) # broadcast each array to the desired dimensions # (except along the axis we are binding along) @@ -112,7 +113,8 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { # bind along desired axis result <- new_rvar( abind(draws_x, draws_y, along = draws_axis, use.dnns = TRUE), - .nchains = nchains(x) + .nchains = nchains(x), + .log_weights = log_weights(x) ) } diff --git a/R/rvar-cast.R b/R/rvar-cast.R index bdcfb209..adde92b7 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -213,6 +213,7 @@ vec_proxy.rvar = function(x, ...) { #' @noRd make_rvar_proxy = function(x) { nchains <- nchains(x) + log_weights <- log_weights(x) draws <- draws_of(x) is <- seq_len(NROW(x)) names(is) <- rownames(x) @@ -220,6 +221,7 @@ make_rvar_proxy = function(x) { list( index = i, nchains = nchains, + log_weights = log_weights, draws = draws ) }) @@ -246,7 +248,9 @@ vec_restore.rvar <- function(x, ...) 
{ # find runs where the same underlying draws are in the proxy different_draws_from_previous <- vapply(seq_along(x)[-1], FUN.VALUE = logical(1), function(i) { - !identical(x[[i]]$draws, x[[i - 1]]$draws) || !identical(x[[i]]$nchains, x[[i - 1]]$nchains) + !identical(x[[i]]$draws, x[[i - 1]]$draws) || + !identical(x[[i]]$nchains, x[[i - 1]]$nchains) || + !identical(x[[i]]$log_weights, x[[i - 1]]$log_weights) }) draws_groups <- cumsum(c(TRUE, different_draws_from_previous)) @@ -254,7 +258,7 @@ vec_restore.rvar <- function(x, ...) { groups <- split(x, draws_groups) rvars <- lapply(groups, function(x) { i <- vapply(x, `[[`, "index", FUN.VALUE = numeric(1)) - rvar <- new_rvar(x[[1]]$draws, .nchains = x[[1]]$nchains) + rvar <- new_rvar(x[[1]]$draws, .nchains = x[[1]]$nchains, .log_weights = x[[1]]$log_weights) if (length(dim(rvar)) > 1) { rvar[i, ] } else { @@ -321,6 +325,7 @@ vec_proxy_equal.rvar = function(x, ...) { make_rvar_proxy_equal = function(x) { lapply(as.list(x), function(x) list( nchains = nchains(x), + log_weights = log_weights(x), draws = draws_of(x) )) } diff --git a/R/rvar-dist.R b/R/rvar-dist.R index 2c390a2c..ecb62389 100755 --- a/R/rvar-dist.R +++ b/R/rvar-dist.R @@ -40,8 +40,9 @@ #' @name rvar-dist #' @export density.rvar <- function(x, at, ...) { + weights <- weights(x) summarise_rvar_by_element(x, function(draws) { - d <- density(draws, cut = 0, ...) + d <- density(draws, weights = weights, cut = 0, ...) f <- approxfun(d$x, d$y, yleft = 0, yright = 0) f(at) }) @@ -50,11 +51,12 @@ density.rvar <- function(x, at, ...) { #' @rdname rvar-dist #' @export density.rvar_factor <- function(x, at, ...) 
{ + weights <- weights(x) at <- as.numeric(factor(at, levels = levels(x))) - nbins <- nlevels(x) summarise_rvar_by_element(x, function(draws) { - props <- prop.table(tabulate(draws, nbins = nbins))[at] + tab <- weighted_simple_table(draws, weights) + props <- prop.table(tab$count)[at] props }) } @@ -66,8 +68,9 @@ distributional::cdf #' @rdname rvar-dist #' @export cdf.rvar <- function(x, q, ...) { + weights <- weights(x) summarise_rvar_by_element(x, function(draws) { - ecdf(draws)(q) + weighted_ecdf(draws, weights)(q) }) } @@ -76,7 +79,7 @@ cdf.rvar <- function(x, q, ...) { cdf.rvar_factor <- function(x, q, ...) { # CDF is not defined for unordered distributions # generate an all-NA array of the appropriate shape - out <- rep_len(NA, length(x) * length(q)) + out <- rep_len(NA_real_, length(x) * length(q)) if (length(x) > 1) dim(out) <- c(length(q), dim(x)) out } @@ -91,14 +94,10 @@ cdf.rvar_ordered <- function(x, q, ...) { #' @rdname rvar-dist #' @export quantile.rvar <- function(x, probs, ...) { - summarise_rvar_by_element_via_matrix(x, - "quantile", - function(draws) { - t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...)) - }, - .extra_dim = length(probs), - .extra_dimnames = list(NULL) - ) + weights <- weights(x) + summarise_rvar_by_element(x, function(draws) { + weighted_quantile(draws, probs = probs, weights = weights, ...) + }) } #' @rdname rvar-dist diff --git a/R/rvar-factor.R b/R/rvar-factor.R index 5f57f643..303b99f9 100644 --- a/R/rvar-factor.R +++ b/R/rvar-factor.R @@ -61,7 +61,13 @@ #' #' @export rvar_factor <- function( - x = factor(), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE, ... + x = factor(), + dim = NULL, + dimnames = NULL, + nchains = NULL, + with_chains = FALSE, + log_weights = NULL, + ... 
) { # to ensure we pick up levels already attached to x (if there are any), we @@ -71,7 +77,12 @@ rvar_factor <- function( } out <- rvar( - x, dim = dim, dimnames = dimnames, nchains = nchains, with_chains = with_chains + x, + dim = dim, + dimnames = dimnames, + nchains = nchains, + with_chains = with_chains, + log_weights = log_weights ) .rvar_to_rvar_factor(out, ...) } @@ -79,11 +90,24 @@ rvar_factor <- function( #' @rdname rvar_factor #' @export rvar_ordered <- function( - x = ordered(NULL), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE, ... + x = ordered(NULL), + dim = NULL, + dimnames = NULL, + nchains = NULL, + with_chains = FALSE, + log_weights = NULL, + ... ) { rvar_factor( - x, dim = dim, dimnames = dimnames, nchains = nchains, with_chains = with_chains, ordered = TRUE, ... + x, + dim = dim, + dimnames = dimnames, + nchains = nchains, + with_chains = with_chains, + log_weights = log_weights, + ordered = TRUE, + ... ) } diff --git a/R/rvar-math.R b/R/rvar-math.R index e704731d..9395ddeb 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -15,6 +15,7 @@ Ops.rvar <- function(e1, e2) { .Ops.rvar <- function(f, e1, e2, preserve_dims = FALSE) { c(e1, e2) %<-% conform_rvar_nchains(list(e1, e2)) + c(e1, e2) %<-% conform_rvar_weights(list(e1, e2)) draws_x <- draws_of(e1) draws_y <- draws_of(e2) @@ -47,7 +48,7 @@ Ops.rvar <- function(e1, e2) { draws <- copy_dims(dim_source, draws) } - new_rvar(draws, .nchains = nchains(e1)) + new_rvar(draws, .nchains = nchains(e1), .log_weights = log_weights(e1)) } #' @export @@ -95,10 +96,12 @@ Math.rvar <- function(x, ...) { if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) { # cumulative functions need to be handled differently # from other functions in this generic - new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x)) + draws_of(x) <- t(apply(draws_of(x), 1, f)) } else { - new_rvar(f(draws_of(x), ...), .nchains = nchains(x)) + draws_of(x) <- f(draws_of(x), ...) 
} + + x } #' @export @@ -186,7 +189,7 @@ Math.rvar_factor <- function(x, ...) { } # conform the draws dimension in both variables - c(x, y) %<-% conform_rvar_ndraws_nchains(list(x, y)) + c(x, y) %<-% conform_rvar_nchains_ndraws_weights(list(x, y)) # drop the names of the dimensions (mul.tensor gets uppity if dimension names # are duplicated, but we don't care about that) @@ -206,7 +209,7 @@ Math.rvar_factor <- function(x, ...) { result <- copy_dimnames(draws_of(x), 1:2, result, 1:2) result <- copy_dimnames(draws_of(y), 3, result, 3) - new_rvar(result, .nchains = nchains(x)) + new_rvar(result, .nchains = nchains(x), .log_weights = log_weights(x)) } # This generic is not exported here as matrixOps is only in R >= 4.3, so we must @@ -246,16 +249,17 @@ chol.rvar <- function(x, ...) { x_tensor <- as.tensor(aperm(draws_of(x), c(2,3,1))) # do the cholesky decomp - result <- unclass(chol.tensor(x_tensor, 1, 2, ...)) + out_draws <- unclass(chol.tensor(x_tensor, 1, 2, ...)) # move draws dimension back to the front - result <- aperm(result, c(3,1,2)) + out_draws <- aperm(out_draws, c(3,1,2)) # drop dimension names (chol.tensor screws them around) - names(dim(result)) <- NULL - dimnames(result) <- NULL + names(dim(out_draws)) <- NULL + dimnames(out_draws) <- NULL - new_rvar(result, .nchains = nchains(x)) + draws_of(x) <- out_draws + x } #' @importFrom methods setGeneric @@ -334,14 +338,15 @@ t.rvar = function(x) { .dimnames = dimnames(.draws) dim(.draws) = c(dim(.draws)[1], 1, dim(.draws)[2]) dimnames(.draws) = c(.dimnames[1], list(NULL), .dimnames[2]) - result <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else if (ndim == 3) { .draws <- copy_levels(.draws, aperm(.draws, c(1, 3, 2))) - result <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else { stop_no_call("argument is not a random vector or matrix") } - result + + x } #' @export diff --git a/R/rvar-print.R b/R/rvar-print.R index 1efdbeab..0e7b7a0f 100755 --- a/R/rvar-print.R +++ 
b/R/rvar-print.R @@ -60,7 +60,7 @@ print.rvar <- function(x, ..., summary = NULL, digits = NULL, color = TRUE, widt digits <- digits %||% getOption("posterior.digits", 2) # \u00b1 = plus/minus sign summary_functions <- get_summary_functions(draws_of(x), summary) - plus_minus <- summary_functions[[1]] != "modal_category" + plus_minus <- !identical(summary_functions[[1]], modal_category) summary_string <- if (plus_minus) { paste0(paste(names(summary_functions), collapse = " \u00b1 "), ":") } else { @@ -89,7 +89,7 @@ print.rvar <- function(x, ..., summary = NULL, digits = NULL, color = TRUE, widt #' @export format.rvar <- function(x, ..., summary = NULL, digits = NULL, color = FALSE) { digits <- digits %||% getOption("posterior.digits", 2) - format_rvar_draws(draws_of(x), ..., summary = summary, digits = digits, color = color) + format_rvar_draws(draws_of(x), weights(x), ..., summary = summary, digits = digits, color = color) } #' @rdname print.rvar @@ -126,7 +126,7 @@ str.rvar <- function( } cat0(" ", rvar_type_full(object), " ", - paste(format_rvar_draws(.draws, summary = summary, trim = TRUE), collapse = " "), + paste(format_rvar_draws(.draws, weights(object), summary = summary, trim = TRUE), collapse = " "), ellipsis, "\n" ) @@ -158,7 +158,11 @@ str.rvar <- function( } } str_attr(attributes(draws_of(object)), "draws_of(*)", c("names", "dim", "dimnames", "class", "levels")) - str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache")) + str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache", "log_weights")) + if ("log_weights" %in% names(attributes(object))) { + cat0(indent.str, paste0('- log_weights(*)=')) + str_next(log_weights(object), ...) 
+ } } invisible(NULL) @@ -218,7 +222,12 @@ rvar_type_full <- function(x, dim1 = TRUE) { paste0(",", nchains(x)) } - paste0(rvar_class(x), "<", niterations(x), chain_str, ">", dim_str) + paste0( + if (!is.null(log_weights(x))) "weighted ", + rvar_class(x), + "<", niterations(x), chain_str, ">", + dim_str + ) } rvar_class <- function(x) { @@ -235,19 +244,19 @@ rvar_class <- function(x) { # formats a draws array for display as individual "variables" (i.e. maintaining # its dimensions except for the dimension representing draws) format_rvar_draws <- function( - draws, ..., pad_left = "", pad_right = "", summary = NULL, digits = 2, color = FALSE, trim = FALSE + draws, weights, ..., pad_left = "", pad_right = "", summary = NULL, digits = 2, color = FALSE, trim = FALSE ) { if (length(draws) == 0) { return(character()) } summary_functions <- get_summary_functions(draws, summary) - plus_minus <- summary_functions[[1]] != "modal_category" + plus_minus <- !identical(summary_functions[[1]], modal_category) summary_dimensions <- seq_len(length(dim(draws)) - 1) + 1 # these will be mean/sd, median/mad, mode/entropy, mode/dissent depending on `summary` - .mean <- .apply_factor(draws, summary_dimensions, summary_functions[[1]]) - .sd <- .apply_factor(draws, summary_dimensions, summary_functions[[2]]) + .mean <- .apply_factor(draws, summary_dimensions, function(x) summary_functions[[1]](x, weights)) + .sd <- .apply_factor(draws, summary_dimensions, function(x) summary_functions[[2]](x, weights)) out <- paste0( pad_left, @@ -313,6 +322,16 @@ format_levels <- function(levels, ordered = FALSE, max_level = NULL, width = get ) } +# matrixStats::weighted_sd assumes we know the sample size, so use +# this instead +weighted_sd <- function(x, w = NULL) { + if (is.null(w)) { + sd(x) + } else { + sqrt(weighted.mean((x - weighted.mean(x, w))^2, w) ) + } +} + # check that summary is a valid name of the type of summary to do and # return a vector of two elements, where the first is the point 
summary function # (mean, median, mode) and the second is the uncertainty function () @@ -325,10 +344,10 @@ get_summary_functions <- function(draws, summary = NULL) { if (is.null(summary)) summary <- getOption("posterior.rvar_summary", "mean_sd") switch(summary, - mean_sd = list(mean = "mean", sd = "sd"), - median_mad = list(median = "median", mad = "mad"), - mode_entropy = list(mode = "modal_category", entropy = "entropy"), - mode_dissent = list(mode = "modal_category", dissent = "dissent"), + mean_sd = list(mean = matrixStats::weightedMean, sd = weighted_sd), + median_mad = list(median = matrixStats::weightedMedian, mad = matrixStats::weightedMad), + mode_entropy = list(mode = modal_category, entropy = entropy), + mode_dissent = list(mode = modal_category, dissent = dissent), stop_no_call('`summary` must be one of "mean_sd" or "median_mad"') ) } diff --git a/R/rvar-rfun.R b/R/rvar-rfun.R index 27931988..79210e6a 100755 --- a/R/rvar-rfun.R +++ b/R/rvar-rfun.R @@ -85,6 +85,7 @@ rfun <- function (.f, rvar_args = NULL, rvar_dots = TRUE, ndraws = NULL) { vapply(args, is_rvar, logical(1)) rvar_args_draws <- as_draws_rvars(args[is_rvar_arg]) .nchains <- max(1, nchains(rvar_args_draws)) + .log_weights <- log_weights(rvar_args_draws) if (length(rvar_args_draws) == 0) { # no rvar arguments, so just create a random variable by applying this function @@ -103,7 +104,7 @@ rfun <- function (.f, rvar_args = NULL, rvar_dots = TRUE, ndraws = NULL) { dim(x) <- c(1, dim(x)) x }) - new_rvar(vctrs::list_unchop(list_of_draws), .nchains = .nchains) + new_rvar(vctrs::list_unchop(list_of_draws), .nchains = .nchains, .log_weights = .log_weights) } formals(rvar_f) <- f_formals rvar_f @@ -231,16 +232,18 @@ rvar_rng <- function(.f, n, ..., ndraws = NULL) { args <- list(...) 
is_rvar_arg <- vapply(args, is_rvar, logical(1)) - rvar_args <- conform_rvar_ndraws_nchains(args[is_rvar_arg]) + rvar_args <- conform_rvar_nchains_ndraws_weights(args[is_rvar_arg]) if (length(rvar_args) < 1) { nchains <- 1 ndraws <- ndraws %||% getOption("posterior.rvar_ndraws", 4000) + log_weights <- NULL } else { # we have some arguments that are rvars. We require them to be single-dimensional # (vectors) so that R's vector recycling will work correctly. nchains <- nchains(rvar_args[[1]]) ndraws <- ndraws(rvar_args[[1]]) + log_weights <- log_weights(rvar_args[[1]]) rvar_args_ndims <- lengths(lapply(rvar_args, dim)) if (!all(rvar_args_ndims == 1)) { @@ -266,5 +269,5 @@ rvar_rng <- function(.f, n, ..., ndraws = NULL) { args <- c(n = nd, args) result <- do.call(.f, args) dim(result) <- c(ndraws, n) - new_rvar(result, .nchains = nchains) + new_rvar(result, .nchains = nchains, .log_weights = log_weights) } diff --git a/R/rvar-slice.R b/R/rvar-slice.R index 0129e37a..4ea47bf6 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -141,25 +141,26 @@ NULL .draws <- draws_of(x)[, i, drop = FALSE] } dimnames(.draws) <- NULL - out <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else if (length(index) == length(dim(x))) { # multiple element selection => must have exactly the right number of dims .draws <- inject(draws_of(x)[, !!!index, drop = FALSE]) # must do drop manually in case the draws dimension has only 1 draw dim(.draws) <- c(ndraws(x), 1) - out <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else { stop_no_call("subscript out of bounds") } - out + + x } #' @rdname rvar-slice #' @export `[[<-.rvar` <- function(x, i, ..., value) { value <- vec_cast(value, x) - c(x, value) %<-% conform_rvar_ndraws_nchains(list(x, value)) - value <- check_rvar_dims_first(value, new_rvar(0)) + c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) + value <- check_rvar_is_scalar(value) index <- check_rvar_yank_index(x, i, ...) 
if (length(index) == 1) { @@ -219,7 +220,7 @@ NULL # this kind of indexing must ignore chains nchains_rvar(x) <- 1L nchains_rvar(i) <- 1L - c(x, i) %<-% conform_rvar_ndraws(list(x, i)) + c(x, i) %<-% conform_rvar_ndraws_weights(list(x, i)) index <- list() draws_index <- list(draws_of(i)) } else { @@ -283,9 +284,10 @@ NULL if (!is_missing(draws_index[[1]])) { # if we subsetted draws, replace draw ids with sequential ids rownames(.draws) <- seq_len(NROW(.draws)) + log_weights_rvar(x) <- inject(log_weights(x)[!!!draws_index]) } - x <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws if (drop) { x <- drop(x) @@ -314,7 +316,7 @@ NULL # for the purposes of this kind of assignment, we check draws only, not chains, # as chain information is irrelevant when subsetting by draw - c(x, i) %<-% conform_rvar_ndraws(list(x, i)) + c(x, i) %<-% conform_rvar_ndraws_weights(list(x, i)) draws_index <- draws_of(i) # necessary number of draws in `value` is determined by whether or not @@ -323,7 +325,7 @@ NULL draws_of(value) <- broadcast_array(draws_of(value), c(value_ndraws, dim(x)), broadcast_scalars = FALSE) i <- missing_arg() } else { - c(x, value) %<-% conform_rvar_ndraws_nchains(list(x, value)) + c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) draws_index <- missing_arg() } @@ -378,7 +380,7 @@ scalar_numeric_rvar_to_index <- function(i_rvar, x, ...) 
{ if (!is.numeric(draws_of(i_rvar)) || length(i_rvar) != 1) { stop_no_call("`x[[i]]` for rvars `x` and `i` is only supported when `i` is a scalar numeric rvar.") } - out <- conform_rvar_ndraws_nchains(list(i_rvar, x, ...)) + out <- conform_rvar_nchains_ndraws_weights(list(i_rvar, x, ...)) c(i_rvar, x) %<-% out[1:2] out[[1]] <- matrix_to_index(cbind(seq_len(ndraws(x)), draws_of(i_rvar)), c(ndraws(x), length(x))) out diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index b0adb9ff..b8cc2b3d 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -68,7 +68,7 @@ E <- function(x, ...) { #' @export mean.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mean", matrixStats::colMeans2, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mean", matrixStats::colWeightedMeans, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... ) } @@ -101,7 +101,7 @@ Pr.rvar <- function(x, ...) { #' @export median.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "median", matrixStats::colMedians, useNames = FALSE, ... + x, "median", matrixStats::colWeightedMedians, useNames = FALSE, w = weights(x), ... ) } @@ -124,6 +124,8 @@ max.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export sum.rvar <- function(x, ...) { + .weights <- weights(x, normalize = FALSE) + if (!is.null(.weights)) x <- x * new_rvar(.weights, .nchains = nchains(x)) summarise_rvar_by_element_via_matrix( x, "sum", matrixStats::colSums2, useNames = FALSE, .ordered_okay = FALSE, ... ) @@ -132,6 +134,8 @@ sum.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export prod.rvar <- function(x, ...) { + .weights <- weights(x, normalize = FALSE) + if (!is.null(.weights)) x <- x ^ new_rvar(.weights, .nchains = nchains(x)) summarise_rvar_by_element_via_matrix( x, "prod", matrixStats::colProds, useNames = FALSE, .ordered_okay = FALSE, ... 
) @@ -172,9 +176,14 @@ distributional::variance #' @rdname rvar-summaries-over-draws #' @export variance.rvar <- function(x, ...) { - summarise_rvar_by_element_via_matrix( - x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... - ) + .weights <- weights(x) + if (is.null(.weights)) { + summarise_rvar_by_element_via_matrix( + x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... + ) + } else { + mean((x - mean(x))^2) + } } #' @rdname rvar-summaries-over-draws @@ -196,9 +205,14 @@ sd.default <- function(x, ...) stats::sd(x, ...) #' @rdname rvar-summaries-over-draws #' @export sd.rvar <- function(x, ...) { - summarise_rvar_by_element_via_matrix( - x, "sd", matrixStats::colSds, useNames = FALSE, .ordered_okay = FALSE, ... - ) + .weights <- weights(x) + if (is.null(.weights)) { + summarise_rvar_by_element_via_matrix( + x, "sd", matrixStats::colWeightedSds, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... + ) + } else { + sqrt(variance(x)) + } } #' @rdname rvar-summaries-over-draws @@ -211,7 +225,7 @@ mad.default <- function(x, ...) stats::mad(x, ...) #' @export mad.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mad", matrixStats::colMads, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mad", matrixStats::colWeightedMads, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... ) } #' @rdname rvar-summaries-over-draws diff --git a/R/subset_draws.R b/R/subset_draws.R index 7802c63f..a765f435 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -374,6 +374,7 @@ subset_dims <- function(x, ...) { for (i in seq_along(x)) { draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index) nchains_rvar(x[[i]]) <- nchains + log_weights_rvar(x[[i]]) <- log_weights(x[[i]])[slice_index] } } if (!is.null(iteration)) { @@ -382,6 +383,7 @@ subset_dims <- function(x, ...) 
{ (rep(chain_ids(x), each = niterations) - 1) * niterations(x) for (i in seq_along(x)) { draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index) + log_weights_rvar(x[[i]]) <- log_weights(x[[i]])[slice_index] } } x diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755a..05acd83b 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary diff --git a/R/weight_draws.R b/R/weight_draws.R index fa8bfd8b..a336e177 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -7,10 +7,11 @@ #' `draws` objects. #' #' @template args-methods-x -#' @param weights (numeric vector) A vector of weights of length `ndraws(x)`. -#' Weights will be internally stored on the log scale (in a variable called -#' `.log_weight`) and will not be normalized, but normalized (non-log) weights -#' can be returned via the [weights.draws()] method later. +#' @param weights (numeric vector) A vector of weights of length `ndraws(x)`, +#' or `NULL` to remove weights. Weights will be internally stored on the log +#' scale and will not be normalized. Normalized (non-log) weights can be +#' returned via the [weights.draws()] method, and the unnormalized +#' log weights can be accessed via [log_weights()]. #' @param log (logical) Are the weights passed already on the log scale? The #' default is `FALSE`, that is, expecting `weights` to be on the standard #' (non-log) scale. 
@@ -45,6 +46,9 @@ #' head(weights(x)) #' head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts #' +#' # log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +#' all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) +#' #' # add weights on log scale and Pareto smooth them #' x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) #' @@ -56,14 +60,9 @@ weight_draws <- function(x, weights, ...) { #' @rdname weight_draws #' @export weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(weights)) return(remove_variables(x, ".log_weight")) - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, ".log_weight"] <- log_weights @@ -78,13 +77,9 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = F #' @rdname weight_draws #' @export weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) 
{ + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(weights)) return(remove_variables(x, ".log_weight")) - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, , ".log_weight"] <- log_weights @@ -99,27 +94,16 @@ weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FA #' @rdname weight_draws #' @export weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } - x$.log_weight <- log_weights + x$.log_weight <- validate_weights(weights, ndraws(x), log, pareto_smooth) x } #' @rdname weight_draws #' @export weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(log_weights)) return(remove_variables(x, ".log_weight")) - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } niterations <- niterations(x) for (i in seq_len(nchains(x))) { sel <- (1 + (i - 1) * niterations):(i * niterations) @@ -131,14 +115,17 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FAL #' @rdname weight_draws #' @export weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) 
{ - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) + .log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + for (i in seq_along(x)) { + log_weights_rvar(x[[i]]) <- .log_weights } - x$.log_weight <- rvar(log_weights) + x +} + +#' @rdname weight_draws +#' @export +weight_draws.rvar <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + log_weights_rvar(x) <- validate_weights(weights, ndraws(x), log, pareto_smooth) x } @@ -146,15 +133,18 @@ weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FA #' #' Extract weights from [`draws`] objects, with one weight per draw. #' See [`weight_draws`] for details how to add weights to [`draws`] objects. +#' `log_weights(x)` is a low-level shortcut for `weights(x, log = TRUE, normalize = FALSE)`, +#' returning the internal log weights without transforming them. #' -#' @param object (draws) A [`draws`] object. +#' @param object (draws) A [`draws`] object or an [`rvar`]. #' @param log (logical) Should the weights be returned on the log scale? #' Defaults to `FALSE`. #' @param normalize (logical) Should the weights be normalized to sum to 1 on #' the standard scale? Defaults to `TRUE`. #' @template args-methods-dots #' -#' @return A vector of weights, with one weight per draw. +#' @return A vector of weights, with one weight per draw, or `NULL` if this +#' object does not contain weights. #' #' @seealso [`weight_draws`], [`resample_draws`] #' @@ -164,10 +154,10 @@ weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FA weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) 
{ log <- as_one_logical(log) normalize <- as_one_logical(normalize) - if (!".log_weight" %in% variables(object, reserved = TRUE)) { - return(NULL) - } - out <- extract_variable(object, ".log_weight") + + out <- log_weights(object) + if (is.null(out)) return(NULL) + if (normalize) { out <- out - log_sum_exp(out) } @@ -177,11 +167,58 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { out } +#' @rdname weights.draws +#' @export +weights.rvar <- weights.draws + +#' @rdname weights.draws +#' @export +log_weights <- function(object, ...) { + UseMethod("log_weights") +} + +#' @rdname weights.draws +#' @export +log_weights.draws <- function(object, ...) { + if (!".log_weight" %in% variables(object, reserved = TRUE)) { + return(NULL) + } + extract_variable(object, ".log_weight") +} + +#' @rdname weights.draws +#' @export +log_weights.draws_rvars <- function(object, ...) { + if (length(object) < 1) return(NULL) + log_weights(object[[1]]) +} + +#' @rdname weights.draws +#' @export +log_weights.rvar <- function(object, ...) 
{ + attr(object, "log_weights") +} +# for internal use only currently: if you are setting the log_weights +# attribute on an rvar, ALWAYS use this function so that the proxy +# cache is invalidated +`log_weights_rvar<-` <- function(x, value) { + if (!identical(attr(x, "log_weights"), value)) { + attr(x, "log_weights") <- value + x <- invalidate_rvar_cache(x) + } + x +} + + # validate weights and return log weights -validate_weights <- function(weights, draws, log = FALSE) { - checkmate::assert_numeric(weights) - checkmate::assert_flag(log) - if (length(weights) != ndraws(draws)) { +validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { + if (is.null(weights)) return(NULL) + assert_numeric(weights) + assert_atomic_vector(weights) + assert_flag(log) + assert_flag(pareto_smooth) + + if (length(weights) != ndraws) { stop_no_call("Number of weights must match the number of draws.") } if (!log) { @@ -190,6 +227,10 @@ validate_weights <- function(weights, draws, log = FALSE) { } weights <- log(weights) } + if (pareto_smooth) { + weights <- pareto_smooth_log_weights(weights) + } + weights } diff --git a/R/weighted.R b/R/weighted.R new file mode 100644 index 00000000..f444ef74 --- /dev/null +++ b/R/weighted.R @@ -0,0 +1,121 @@ +# weighted distribution functions -------------------------------------------- + +#' Weighted version of [stats::ecdf()]. +#' Based on ggdist::weighted_ecdf(). +#' @noRd +weighted_ecdf = function(x, weights = NULL) { + n = length(x) + if (n < 1) stop("Need at least 1 or more values to calculate an ECDF") + + weights = if (is.null(weights)) rep(1, n) else weights + + #sort only if necessary + if (is.unsorted(x)) { + sort_order = order(x) + x = x[sort_order] + weights = weights[sort_order] + } + + # calculate weighted cumulative probabilities + p = cumsum(weights) + p = p/p[n] + + approxfun(x, p, yleft = 0, yright = 1, ties = "ordered", method = "constant") +} + +#' Weighted version of [stats::quantile()]. 
+#' Based on ggdist::weighted_quantile(). +#' @noRd +weighted_quantile = function(x, + probs = seq(0, 1, 0.25), + weights = NULL, + na.rm = FALSE, + type = 7, + ... +) { + weighted_quantile_fun( + x, + weights = weights, + na.rm = na.rm, + type = type, + ... + )(probs) +} + +#' @rdname weighted_quantile +#' @export +weighted_quantile_fun = function(x, weights = NULL, na.rm = FALSE, type = 7, ...) { + na.rm <- as_one_logical(na.rm) + assert_number(type, lower = 1, upper = 9) + + if (na.rm) { + keep = !is.na(x) & !is.na(weights %||% 1) # NULL weights: keep every non-NA x + x = x[keep] + weights = weights[keep] + } else if (anyNA(x)) { + # quantile itself doesn't handle this case (#110) + return(function(p) rep(NA_real_, length(p))) + } + + # determine weights + weights = weights %||% rep(1, length(x)) + non_zero = weights != 0 + x = x[non_zero] + weights = weights[non_zero] + weights = weights / sum(weights) + + # if there are only 0 or 1 x values, we don't need the weighted version (and + # we couldn't calculate it anyway as we need >= 2 points for the interpolation) + if (length(x) <= 1) { + return(function(p) quantile(x, p, names = FALSE)) + } + + # sort values if necessary + if (is.unsorted(x)) { + x_order = order(x) + x = x[x_order] + weights = weights[x_order] + } + + # calculate the weighted CDF + F_k = cumsum(weights) + + # generate the function for the approximate inverse CDF + if (1 <= type && type <= 3) { + # discontinuous quantiles + switch(type, + # type 1 + stepfun(F_k, c(x, x[length(x)]), right = TRUE), + # type 2 + { + x_over_2 = c(x, x[length(x)])/2 + inverse_cdf_type2_left = stepfun(F_k, x_over_2, right = FALSE) + inverse_cdf_type2_right = stepfun(F_k, x_over_2, right = TRUE) + function(x) inverse_cdf_type2_left(x) + inverse_cdf_type2_right(x) + }, + # type 3 + stepfun(F_k - weights/2, c(x[[1]], x), right = TRUE) + ) + } else { + # Continuous quantiles. These are based on the definition of p_k as described + # in the documentation of `quantile()`.
The trick to re-writing those formulas + # (which use `n` and `k`) for the weighted case is that `k` = `F_k * n` and + # `1/n` = `weight_k`. Using these two facts, we can express the formulas for + # `p_k` without using `n` or `k`, which don't really apply in the weighted case. + p_k = switch(type - 3, + # type 4 + F_k, + # type 5 + F_k - weights/2, + # type 6 + F_k / (1 + weights), + # type 7 + (F_k - weights) / (1 - weights), + # type 8 + (F_k - weights/3) / (1 + weights/3), + # type 9 + (F_k - weights*3/8) / (1 + weights/4) + ) + approxfun(p_k, x, rule = 2, ties = "ordered") + } +} diff --git a/man-roxygen/args-summaries-weights.R b/man-roxygen/args-summaries-weights.R new file mode 100644 index 00000000..d7089d74 --- /dev/null +++ b/man-roxygen/args-summaries-weights.R @@ -0,0 +1,2 @@ +#' @param weights (numeric vector) A vector of weights of the same length as `x`, +#' or `NULL` for unweighted estimation. diff --git a/man-roxygen/args-summaries-x-categorical.R b/man-roxygen/args-summaries-x-categorical.R new file mode 100644 index 00000000..6bc363cf --- /dev/null +++ b/man-roxygen/args-summaries-x-categorical.R @@ -0,0 +1,5 @@ +#' @param x (multiple options) A vector to be interpreted as draws from +#' a categorical distribution, such as: +#' - A [factor] +#' - A [numeric] (should be [integer] or integer-like) +#' - An [rvar], [rvar_factor], or [rvar_ordered] diff --git a/man/dissent.Rd b/man/dissent.Rd index 4166ee54..d16431c1 100644 --- a/man/dissent.Rd +++ b/man/dissent.Rd @@ -6,11 +6,11 @@ \alias{dissent.rvar} \title{Dissention} \usage{ -dissent(x) +dissent(x, ...) -\method{dissent}{default}(x) +\method{dissent}{default}(x, weights = NULL, ...) -\method{dissent}{rvar}(x) +\method{dissent}{rvar}(x, ...)
} \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ an ordinal distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 numeric vector with a value diff --git a/man/entropy.Rd b/man/entropy.Rd index 429068f0..8d8657bb 100644 --- a/man/entropy.Rd +++ b/man/entropy.Rd @@ -6,11 +6,11 @@ \alias{entropy.rvar} \title{Normalized entropy} \usage{ -entropy(x) +entropy(x, ...) -\method{entropy}{default}(x) +\method{entropy}{default}(x, weights = NULL, ...) -\method{entropy}{rvar}(x) +\method{entropy}{rvar}(x, ...) } \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ a categorical distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 numeric vector with a value diff --git a/man/modal_category.Rd b/man/modal_category.Rd index 8fd8300f..2f8351d6 100644 --- a/man/modal_category.Rd +++ b/man/modal_category.Rd @@ -6,11 +6,11 @@ \alias{modal_category.rvar} \title{Modal category} \usage{ -modal_category(x) +modal_category(x, ...) -\method{modal_category}{default}(x) +\method{modal_category}{default}(x, weights = NULL, ...) -\method{modal_category}{rvar}(x) +\method{modal_category}{rvar}(x, ...) 
} \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ a categorical distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 vector containing diff --git a/man/rvar.Rd b/man/rvar.Rd index 3c144c72..24c745ef 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -9,7 +9,8 @@ rvar( dim = NULL, dimnames = NULL, nchains = NULL, - with_chains = FALSE + with_chains = FALSE, + log_weights = NULL ) } \arguments{ @@ -50,6 +51,11 @@ used to determine the number of chains. If \code{TRUE}, the \code{nchains} argum is ignored and the second dimension of \code{x} is used to index chains. Internally, the array will be converted to a format without the chain index. Ignored when \code{x} is already an \code{\link{rvar}}.} + +\item{log_weights}{(numeric vector) A vector of log weights of length \code{ndraws(x)}. +Weights will be internally stored on the log scale and will not be normalized, +but normalized (non-log) weights can be returned via the \code{\link[=weights.rvar]{weights.rvar()}} +method later.} } \value{ An object of class \code{"rvar"} representing a random variable. @@ -83,6 +89,64 @@ As \code{\link[=rfun]{rfun()}} and \code{\link[=rdo]{rdo()}} incur some performa on the underlying array using the \code{\link[=draws_of]{draws_of()}} function. To re-use existing random number generator functions to efficiently create \code{rvar}s, use \code{\link[=rvar_rng]{rvar_rng()}}. 
} +\section{\code{rvar} Internals}{ + + +The \code{rvar} datatype is not intended to be modified directly; rather, you should +only use exported functions from \pkg{posterior}, such as \code{\link[=rvar]{rvar()}}, \code{\link[=draws_of]{draws_of()}}, +\code{\link[=log_weights]{log_weights()}}, and \code{\link[=weight_draws]{weight_draws()}} to create and manipulate \code{rvar}s. +For completeness, and to aid internal development, this section documents the +internal structure of the \code{rvar} datatype. While the public-facing API is +intended to be stable, \strong{this internal structure is subject to change without +notice}. + +An \code{rvar} \code{x} consists of: +\itemize{ +\item A zero-length \code{list()} with class \code{c("rvar", "vctrs_vctr")}. If \code{draws_of(x)} +is a \code{\link{factor}}, the class will be \code{c("rvar_factor", "rvar", "vctrs_vctr")}, +and if \code{draws_of(x)} is an \code{\link{ordered}}, the class will be +\code{c("rvar_ordered", "rvar_factor", "rvar", "vctrs_vctr")}. These classes are +set automatically if the underlying draws are modified. + +The list has these attributes: +\itemize{ +\item \code{draws}: An \code{\link{array}} containing the draws, where the first dimension +indexes draws. \strong{Always} get this attribute using \code{\link[=draws_of]{draws_of()}} and set it +using \code{draws_of(x) <- value}. To simplify programming, \code{length(dim(draws_of(x)))} +is guaranteed to always be greater than or equal to 2. Zero-length \code{rvar}s +have \code{dim(draws_of(x)) = c(1,0)}. The draws may be a \code{\link{numeric}}, +\code{\link{integer}}, \code{\link{logical}}, \code{\link{factor}}, or \code{\link{ordered}} array. + +The dimensions after the first are reported as the dimensions of \code{x}; i.e. +\code{dim(x) = dim(draws_of(x))[-1]} and \code{dimnames(x) = dimnames(draws_of(x))[-1]}. 
+Because \code{rvar}s \emph{always} have dimensions (unlike base R datatypes, where +there is a distinction between a length-\emph{n} vector with no dimensions and +a length-\emph{n} array with only 1 dimension), \code{names(x) = dimnames(x)[[1]]}; +i.e., \code{names()} refers to the names along the first dimension only. +\item \code{nchains}: A scalar \code{\link{numeric}} giving the number of chains in this \code{rvar}. +\strong{Always} get this attribute using \code{\link[=nchains]{nchains()}}. It cannot be set using the +public (exported) API, but can be modified through other functions (e.g. +\code{\link[=merge_chains]{merge_chains()}} or by creating a new \code{\link[=rvar]{rvar()}}). In internal code, \strong{always} +set it using \code{nchains_rvar(x) <- value}. +\item \code{log_weights}: A vector \code{\link{numeric}} with length \code{ndraws(x)} giving the +log weight on each draw of this \code{rvar}, or \code{NULL} if the \code{rvar} is not +weighted. \strong{Always} get this attribute using \code{\link[=weights]{weights()}} or \code{\link[=log_weights]{log_weights()}}, +and set this attribute using \code{\link[=weight_draws]{weight_draws()}}. In internal code, it may +also be modified directly using \code{log_weights_rvar(x) <- value}. +\item \code{cache}: An \code{\link{environment}} that may contain cached output of the \pkg{vctrs} +proxy functions on \code{x} to improve performance of code that makes multiple +calls to those functions. The cache is updated automatically and invalidated +when necessary so long as the \code{rvar} is only modified using the functions +described in this section (or other functions in the publicly-exported +\code{rvar} API). The environment may contain these variables: +\itemize{ +\item \code{vec_proxy}: cached output of \code{\link[vctrs:vec_proxy]{vctrs::vec_proxy()}}. +\item \code{vec_proxy_equal}: cached output of \code{\link[vctrs:vec_proxy_equal]{vctrs::vec_proxy_equal()}}.
+} +} +} +} + \examples{ set.seed(1234) diff --git a/man/rvar_factor.Rd b/man/rvar_factor.Rd index 92ea056b..315a3972 100644 --- a/man/rvar_factor.Rd +++ b/man/rvar_factor.Rd @@ -11,6 +11,7 @@ rvar_factor( dimnames = NULL, nchains = NULL, with_chains = FALSE, + log_weights = NULL, ... ) @@ -20,6 +21,7 @@ rvar_ordered( dimnames = NULL, nchains = NULL, with_chains = FALSE, + log_weights = NULL, ... ) } @@ -62,6 +64,11 @@ is ignored and the second dimension of \code{x} is used to index chains. Internally, the array will be converted to a format without the chain index. Ignored when \code{x} is already an \code{\link{rvar}}.} +\item{log_weights}{(numeric vector) A vector of log weights of length \code{ndraws(x)}. +Weights will be internally stored on the log scale and will not be normalized, +but normalized (non-log) weights can be returned via the \code{\link[=weights.rvar]{weights.rvar()}} +method later.} + \item{...}{ Arguments passed on to \code{\link[base:factor]{base::factor}} \describe{ diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index d866d466..673223f0 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -7,6 +7,7 @@ \alias{weight_draws.draws_df} \alias{weight_draws.draws_list} \alias{weight_draws.draws_rvars} +\alias{weight_draws.rvar} \title{Weight \code{draws} objects} \usage{ weight_draws(x, weights, ...) @@ -20,15 +21,18 @@ weight_draws(x, weights, ...) \method{weight_draws}{draws_list}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) \method{weight_draws}{draws_rvars}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) + +\method{weight_draws}{rvar}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) } \arguments{ \item{x}{(draws) A \code{draws} object or another \R object for which the method is defined.} -\item{weights}{(numeric vector) A vector of weights of length \code{ndraws(x)}. 
-Weights will be internally stored on the log scale (in a variable called -\code{.log_weight}) and will not be normalized, but normalized (non-log) weights -can be returned via the \code{\link[=weights.draws]{weights.draws()}} method later.} +\item{weights}{(numeric vector) A vector of weights of length \code{ndraws(x)}, +or \code{NULL} to remove weights. Weights will be internally stored on the log +scale and will not be normalized. Normalized (non-log) weights can be +returned via the \code{\link[=weights.draws]{weights.draws()}} method, and the unnormalized +log weights can be accessed via \code{\link[=log_weights]{log_weights()}}.} \item{...}{Arguments passed to individual methods (if applicable).} @@ -73,6 +77,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) + # add weights on log scale and Pareto smooth them x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 1a47788e..7b4d865c 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -2,12 +2,27 @@ % Please edit documentation in R/weight_draws.R \name{weights.draws} \alias{weights.draws} +\alias{weights.rvar} +\alias{log_weights} +\alias{log_weights.draws} +\alias{log_weights.draws_rvars} +\alias{log_weights.rvar} \title{Extract Weights from Draws Objects} \usage{ \method{weights}{draws}(object, log = FALSE, normalize = TRUE, ...) + +\method{weights}{rvar}(object, log = FALSE, normalize = TRUE, ...) + +log_weights(object, ...) + +\method{log_weights}{draws}(object, ...) + +\method{log_weights}{draws_rvars}(object, ...) + +\method{log_weights}{rvar}(object, ...) 
} \arguments{ -\item{object}{(draws) A \code{\link{draws}} object.} +\item{object}{(draws) A \code{\link{draws}} object or an \code{\link{rvar}}.} \item{log}{(logical) Should the weights be returned on the log scale? Defaults to \code{FALSE}.} @@ -18,11 +33,14 @@ the standard scale? Defaults to \code{TRUE}.} \item{...}{Arguments passed to individual methods (if applicable).} } \value{ -A vector of weights, with one weight per draw. +A vector of weights, with one weight per draw, or \code{NULL} if this +object does not contain weights. } \description{ Extract weights from \code{\link{draws}} objects, with one weight per draw. See \code{\link{weight_draws}} for details how to add weights to \code{\link{draws}} objects. +\code{log_weights(x)} is a low-level shortcut for \code{weights(x, log = TRUE, normalize = FALSE)}, +returning the internal log weights without transforming them. } \examples{ x <- example_draws() @@ -48,6 +66,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) + # add weights on log scale and Pareto smooth them x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 96194d8d..1e9034a9 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -145,3 +145,65 @@ test_that("autocovariance returns correct results", { ac2 <- acf(x, type = "covariance", lag.max = length(x), plot = FALSE)$acf[, 1, 1] expect_equal(ac1, ac2) }) + +test_that("NA quantile2 works", { + expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) +}) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + 
rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + xw3 <- weight_draws(xr, weights = w3) + + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) diff --git 
a/tests/testthat/test-discrete-summaries.R b/tests/testthat/test-discrete-summaries.R index 78ce902b..5352aa3a 100644 --- a/tests/testthat/test-discrete-summaries.R +++ b/tests/testthat/test-discrete-summaries.R @@ -24,7 +24,6 @@ test_that("modal_category works on rvars", { expect_equal(modal_category(c(rvar(c("a","b","b","c","c")), rvar("c"))), c("b","c")) }) - # entropy ----------------------------------------------------------------- test_that("entropy works on vectors", { @@ -83,3 +82,32 @@ test_that("dissent works on rvars", { # know about the missing level at the end expect_equal(dissent(as_rvar_numeric(x)), c(-sum(p * log2(1 - abs(1:3 - 1.75) / 2)), 0, 1)) }) + + +# weighted summaries ------------------------------------------------------ + +test_that("weighted discrete summaries work", { + x <- c(0, 0, 0, 3, 3, 1) + levs <- c("h","e","f","g") + x_factor <- factor(c("h","h","h","g","g","e"), levels = levs) + x_ordered <- ordered(c("h","h","h","g","g","e"), levels = levs) + xw <- c(1, 2, 0, 3, 0) + xw_factor <- factor(c("e","f","h","g","h"), levels = levs) + xw_ordered <- ordered(c("e","f","h","g","h"), levels = levs) + w <- c(1, 0, 1.25, 2, 1.75) + + expect_equal(modal_category(xw, w), modal_category(x)) + expect_equal(modal_category(xw_factor, w), modal_category(x_factor)) + expect_equal(modal_category(xw_ordered, w), modal_category(x_ordered)) + + # entropy(xw, w) is equal to entropy(x_factor) because entropy(x_factor) + # accounts for the missing level just as entropy(xw, w) accounts for the + # element with 0 weight. entropy(x) cannot do this. 
+ expect_equal(entropy(xw, w), entropy(x_factor)) + expect_equal(entropy(xw_factor, w), entropy(x_factor)) + expect_equal(entropy(xw_ordered, w), entropy(x_ordered)) + + expect_equal(dissent(xw, w), dissent(x)) + expect_equal(dissent(xw_factor, w), dissent(x_factor)) + expect_equal(dissent(xw_ordered, w), dissent(x_ordered)) +}) diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index 89635600..abf540a6 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -203,3 +203,29 @@ test_that("pareto_smooth works for log_weights", { expect_true(ps$diagnostics$khat > 0.7) }) + + + +test_that("pareto khat works for weighted rvars", { + + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 5), should have high pareto-khat + w2 <- as.numeric(dnorm(x, sd = 5) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + k <- pareto_khat(xw2)$khat + kw <- pareto_khat(w2, are_log_weights = TRUE)$khat + kp <- pareto_khat(draws_of(xw2) * w2)$khat + + expect_true(k > 0.7) + expect_equal(k, max(kw, kp)) +}) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R index 2b79fd39..155be0f5 100644 --- a/tests/testthat/test-print.R +++ b/tests/testthat/test-print.R @@ -53,7 +53,6 @@ test_that("print.draws_rvars runs without errors", { skip_on_cran() - skip_on_os("windows") x <- as_draws_rvars(example_draws()) out <- capture.output(print(x)) expect_match( @@ -65,7 +64,7 @@ test_that("print.draws_rvars runs without errors", { x <- weight_draws(x, rep(1, ndraws(x))) expect_output( print(x), - "hidden reserved variables ..\\.log_weight.."
+ "weighted rvar" ) }) @@ -112,7 +111,6 @@ test_that("print.draws_list handles reserved variables correctly", { test_that("print.draws_rvars handles reserved variables correctly", { skip_on_cran() - skip_on_os("windows") x <- as_draws_rvars(example_draws()) variables(x)[1] <- ".log_weight" # reserved name expect_output(print(x, max_variables = 1), "tau") diff --git a/tests/testthat/test-resample_draws.R b/tests/testthat/test-resample_draws.R index bea49809..42abba81 100644 --- a/tests/testthat/test-resample_draws.R +++ b/tests/testthat/test-resample_draws.R @@ -79,6 +79,12 @@ test_that("resample_draws works on rvars", { expect_true(mean_rs > 6660 && mean_rs < 6670) expect_true(is_rvar(x_rs)) + x_rs <- resample_draws(weight_draws(x, w), method = "stratified") + mean_rs <- mean(x_rs) + expect_true(mean_rs > 6660 && mean_rs < 6670) + expect_true(is_rvar(x_rs)) + expect_null(log_weights(x_rs)) + # without weights x_rs <- resample_draws(x, method = "stratified") mean_rs <- mean(x_rs) diff --git a/tests/testthat/test-rvar-bind.R b/tests/testthat/test-rvar-bind.R index 6f580ed5..b10cc819 100755 --- a/tests/testthat/test-rvar-bind.R +++ b/tests/testthat/test-rvar-bind.R @@ -143,6 +143,15 @@ test_that("c works on rvar_ordered", { expect_equal(c(x_col, y), x_y) }) +test_that("binding weighted and unweighted rvars works", { + x = rvar(1:10) + xw = rvar(1:10, log_weights = 1:10) + + # binding weighted to unweighted constant is okay + expect_equal(c(xw, 1), rvar(cbind(1:10, 1), log_weights = 1:10)) + # but binding weights to unweighted non-constant is not okay + expect_error(c(xw, x), "different log weights") +}) # cbind.rvar -------------------------------------------------------------- diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index 384cb05e..e500b6dc 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -202,6 +202,18 @@ test_that("casting to/from rvar/distribution objects works", { 
expect_error(vctrs::vec_cast(x_mv, null_dist)) }) +test_that("vec_c works with rvar and distributions", { + x_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4)) + y_dist <- distributional::dist_sample(list(c = 5:6, d = 7:8)) + xy_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4, c = 5:6, d = 7:8)) + x_rvar <- rvar(matrix(c(1:4), ncol = 2, dimnames = list(NULL, c("a","b")))) + y_rvar <- rvar(matrix(c(5:8), ncol = 2, dimnames = list(NULL, c("c","d")))) + xy_rvar <- rvar(matrix(c(1:8), ncol = 4, dimnames = list(NULL, c("a","b","c","d")))) + + expect_equal(vctrs::vec_c(x_dist, y_rvar), xy_dist) + expect_equal(vctrs::vec_c(x_rvar, y_dist), xy_rvar) +}) + # type predicates --------------------------------------------------------- diff --git a/tests/testthat/test-rvar-dist.R b/tests/testthat/test-rvar-dist.R index 09c2460c..95f1a571 100755 --- a/tests/testthat/test-rvar-dist.R +++ b/tests/testthat/test-rvar-dist.R @@ -9,6 +9,10 @@ test_that("distributional functions work on a scalar rvar", { expect_equal(cdf(x, x_values), x_cdf) expect_equal(quantile(x, 1:4/4), quantile(x_values, 1:4/4, names = FALSE)) + + expect_equal(quantile(rvar(1:4), 0:4/4 + .Machine$double.eps, type = 1), c(1:4, 4)) + expect_equal(quantile(rvar(1:4), 0:4/4, type = 2), c(1, 1.5, 2.5, 3.5, 4)) + expect_equal(quantile(rvar(1:4), 0:4/4 + .Machine$double.eps, type = 3), c(1, 1:4)) }) test_that("distributional functions work on an rvar array", { @@ -33,7 +37,7 @@ test_that("distributional functions work on an rvar array", { q21 <- quantile(4:6, p) q12 <- quantile(7:9, p) q22 <- quantile(10:12, p) - x_quantiles <- array(c(q11, q21, q12, q22), dim = c(9, 2, 2), dimnames = list(NULL)) + x_quantiles <- array(c(q11, q21, q12, q22), dim = c(9, 2, 2)) expect_equal(quantile(x, p), x_quantiles) }) @@ -41,12 +45,16 @@ test_that("distributional functions work on an rvar_factor", { x_values <- c(2,2,2,4,4,4,4,3,5,3) x_letters <- letters[x_values] x <- rvar_factor(x_letters, levels = letters[1:5]) + 
x2 <- c(rvar_factor(letters), rvar_factor(letters)) - expect_equal(density(x, letters[1:6]), c(0, .3, .2, .4, .1, NA)) + expect_equal(density(x, letters[1:6]), c(0, .3, .2, .4, .1, NA_real_)) + expect_equal(density(x2, letters[1:3]), array(rep(1/26, 6), dim = c(3,2))) - expect_equal(cdf(x, letters[1:5]), c(NA, NA, NA, NA, NA)) + expect_equal(cdf(x, letters[1:5]), c(NA_real_, NA_real_, NA_real_, NA_real_, NA_real_)) + expect_equal(cdf(x2, letters[1:3]), array(rep(NA_real_, 6), dim = c(3,2))) - expect_equal(quantile(x, 1:4/4), c(NA, NA, NA, NA)) + expect_equal(quantile(x, 1:4/4), c(NA_real_, NA_real_, NA_real_, NA_real_)) + expect_equal(quantile(x2, 1:3/3), array(rep(NA_real_, 6), dim = c(3,2))) }) test_that("distributional functions work on an rvar_ordered", { @@ -60,3 +68,55 @@ test_that("distributional functions work on an rvar_ordered", { expect_equal(quantile(x, c(.3, .5, .9, 1)), letters[2:5]) }) + +# weighted rvar ----------------------------------------------------------- + +test_that("weighted rvar works", { + x1_draws = qnorm(ppoints(10)) + x2_draws = qnorm(ppoints(10), 5) + w1 = rep(1, 10) + w2 = rep(2, 10) + w3 = rep(0, 10) + x = rvar(c(x1_draws, x2_draws, rep(10, 10)), log_weights = log(c(w1, w2, w3))) + + expect_equal( + density(x, 0:9, bw = 2.25), + density(draws_of(x), weights = weights(x), bw = 2.25, from = 0, to = 9, n = 10)$y, + tolerance = 1e-4 + ) + expect_equal(cdf(x, 0:9), ecdf(x1_draws)(0:9)/3 + ecdf(x2_draws)(0:9)*2/3) + expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 1), c(x1_draws, x2_draws)) + expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 4), c(x1_draws, x2_draws)) + expect_equal(unname(quantile2(x, cdf(x, c(x1_draws, x2_draws)), type = 1)), c(x1_draws, x2_draws)) + expect_equal(unname(quantile2(x, cdf(x, c(x1_draws, x2_draws)), type = 4)), c(x1_draws, x2_draws)) + + x_na <- rvar(c(draws_of(x), NA_real_), log_weights = c(log_weights(x), 1)) + expect_equal(quantile(x_na, c(0.25, 0.5, 0.75), type = 4), 
c(NA_real_, NA_real_, NA_real_)) + expect_equal( + quantile(x_na, c(0.25, 0.5, 0.75), type = 7, na.rm = TRUE), + quantile(x, c(0.25, 0.5, 0.75), type = 7) + ) + + expect_equal(quantile(rvar(1), 0.5), 1) + expect_equal(quantile(rvar(), 0.5), numeric()) +}) + +test_that("weighted rvar_factor works", { + x = rvar_factor(c("b", "g", "f", "g"), levels = letters, log_weights = log(c(1/2, 1/6, 1/6, 1/6))) + + expect_equal(density(x, letters), c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19))) + expect_equal(cdf(x, letters), rep(NA_real_, 26)) + expect_equal(quantile(x, c(0.2, 0.8)), rep(NA_real_, 2)) +}) + +test_that("weighted rvar_ordered works", { + x = rvar_ordered(c("b", "g", "f", "g"), levels = letters, log_weights = log(c(1/2, 1/6, 1/6, 1/6))) + + expect_equal(density(x, letters), c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19))) + expect_equal(cdf(x, letters), cumsum(c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19)))) + expect_equal(quantile(x, c(0.2, 0.6, 0.8)), c("b", "f", "g")) + + xl = weight_draws(rvar_ordered(letters), 1:26) + expect_equal(quantile(xl, cdf(xl, letters) - .Machine$double.eps), letters) +}) + diff --git a/tests/testthat/test-rvar-print.R b/tests/testthat/test-rvar-print.R index fe28a6ab..63bad5ba 100755 --- a/tests/testthat/test-rvar-print.R +++ b/tests/testthat/test-rvar-print.R @@ -80,6 +80,14 @@ test_that("print() works", { regexp = "12 levels: a b c d e f g h i j k l", all = FALSE ) + + x_long <- rvar_factor(combn(letters, 2, paste, collapse = "")) + out <- capture.output(print(x_long, color = FALSE, width = 50)) + expect_match( + out, + regexp = "325 levels: ab ac ad ae af ag ah ai aj \\.\\.\\. 
yz", + all = FALSE + ) }) test_that("print() works", { @@ -108,6 +116,61 @@ test_that("print() works", { ) }) +test_that("printing weighted rvars works", { + w <- c(1, 0, 1.25, 2, 1.75) + levs <- c("h","e","f","g") + xw <- weight_draws(rvar(c(1, 2, 0, 3, 0)), w) + xw_factor <- weight_draws(rvar_factor(c("e","f","h","g","h"), levels = levs), w) + xw_ordered <- weight_draws(rvar_ordered(c("e","f","h","g","h"), levels = levs), w) + + out <- capture.output(print(xw, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar<5>\\[1\\] mean . sd:", + all = FALSE + ) + expect_match( + out, + regexp = "1.2 . 1.3", + all = FALSE + ) + + out <- capture.output(print(xw, summary = "median_mad", color = FALSE)) + expect_match( + out, + regexp = "weighted rvar<5>\\[1\\] median . mad:", + all = FALSE + ) + expect_match( + out, + regexp = "0.64 . 0.94", + all = FALSE + ) + + out <- capture.output(print(xw_factor, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar_factor<5>\\[1\\] mode :", + all = FALSE + ) + expect_match( + out, + regexp = "h <0.73>", + all = FALSE + ) + + out <- capture.output(print(xw_ordered, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar_ordered<5>\\[1\\] mode :", + all = FALSE + ) + expect_match( + out, + regexp = "h <0.82>", + all = FALSE + ) +}) # str --------------------------------------------------------------------- @@ -200,9 +263,40 @@ test_that("str() works", { ) }) +test_that("str() works", { + x <- rvar(1:100, log_weights = 2:101) + + expect_output(str(weight_draws(rvar(), 1)), + " weighted rvar<1>\\[0\\] " + ) + out <- capture.output(str(x)) + expect_match( + out, + regexp = " weighted rvar<100>\\[1\\] 99 . 
0.96", + all = FALSE + ) + expect_match( + out, + regexp = " - log_weights\\(\\*\\)= int \\[1:100\\] 2 3 4 5", + all = FALSE + ) +}) + # other ------------------------------------------------------------------- +test_that("tibble printing works", { + skip_on_cran() + + x <- rvar(1:10) + out <- capture.output(print(tibble::tibble(x))) + expect_match( + out, + regexp = " 5.5 . 3", + all = FALSE + ) +}) + test_that("glimpse on rvar works", { skip_on_cran() x_vec <- rvar(array(1:24, dim = c(6,4))) diff --git a/tests/testthat/test-rvar-summaries-over-draws.R b/tests/testthat/test-rvar-summaries-over-draws.R index 5d1cd62d..86713ab7 100755 --- a/tests/testthat/test-rvar-summaries-over-draws.R +++ b/tests/testthat/test-rvar-summaries-over-draws.R @@ -193,3 +193,23 @@ test_that("anyNA works", { x_ord[2,1] <- NA expect_equal(anyNA(x_ord), TRUE) }) + + +# weighted summaries ------------------------------------------------------ + +test_that("weighted summaries work", { + x <- rvar(c(1,1,2,2,2,3,3,3,3)) + n <- ndraws(x) + w <- c(2,3,4,0) + xw <- weight_draws(rvar(c(1,2,3,4)), w) + + expect_equal(sum(xw), sum(x)) + expect_equal(prod(xw), prod(x)) + expect_equal(mean(xw), mean(x)) + expect_equal(median(xw), matrixStats::weightedMedian(draws_of(xw), w)) + expect_equal(mad(xw), matrixStats::weightedMad(draws_of(xw), w)) + # weighted var and sd don't use sample correction because it depends on + # knowing the sample size + expect_equal(var(xw), var(x)*(n-1)/n) + expect_equal(sd(xw), sqrt(var(x)*(n-1)/n)) +}) diff --git a/tests/testthat/test-subset_draws.R b/tests/testthat/test-subset_draws.R index 66dfcfc2..b5e44602 100644 --- a/tests/testthat/test-subset_draws.R +++ b/tests/testthat/test-subset_draws.R @@ -137,7 +137,7 @@ test_that("subset_draws works correctly for draws_rvars objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") - expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + expect_equal(variables(x_sub, 
reserved = TRUE), c("mu")) x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) diff --git a/tests/testthat/test-summarise_draws.R b/tests/testthat/test-summarise_draws.R index 80ed971e..7c146501 100644 --- a/tests/testthat/test-summarise_draws.R +++ b/tests/testthat/test-summarise_draws.R @@ -100,24 +100,24 @@ test_that(paste( x <- as_draws_array(test_array) sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) dimnames(x)$variable[2] <- reserved_variables()[1] sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) n <- 1 test_array <- array(data = rnorm(1000*nc*n), dim = c(1000,nc,n)) x <- as_draws_array(test_array) sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) dimnames(x)$variable[1] <- reserved_variables()[1] suppressWarnings(sum_x <- summarise_draws(x)) suppressWarnings(parsum_x <- summarise_draws(x, .cores = cores)) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) }) test_that("summarise_draws supports tibble::set_num_opts correctly", { diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index fb6e6cc3..cd94a7aa 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -9,6 +9,9 @@ test_that("weight_draws works on draws_matrix", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_array", { @@ -22,6 +25,9 @@ test_that("weight_draws works on draws_array", { x2 <- weight_draws(x, 
log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_df", { @@ -35,6 +41,9 @@ test_that("weight_draws works on draws_df", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_list", { @@ -48,6 +57,9 @@ test_that("weight_draws works on draws_list", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_rvars", { @@ -61,6 +73,49 @@ test_that("weight_draws works on draws_rvars", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) +}) + +test_that("weights are propagated to variables in draws_rvars", { + d <- draws_rvars(x = rvar(1:10, log_weights = 2:11), y = 3:12) + expect_equal(log_weights(d$x), 2:11) + expect_equal(log_weights(d$y), 2:11) + + d <- draws_rvars(x = 1:10, y = 3:12, .log_weight = 2:11) + expect_equal(log_weights(d$x), 2:11) + expect_equal(log_weights(d$y), 2:11) + + expect_error( + draws_rvars(x = rvar(1:10, log_weights = 1:10), y = rvar(3:12, log_weights = 2:11)), + "different log weights" + ) + + expect_error( + draws_rvars(x = rvar(1:10, log_weights = 1:10), .log_weight = 2:11), + "different log weights" + ) +}) + +# removing weights works -------------------------------------------------- + +test_that("weights can be removed", { + x <- list( + matrix = 
as_draws_matrix(example_draws()), + array = as_draws_array(example_draws()), + df = as_draws_df(example_draws()), + list = as_draws_list(example_draws()), + rvars = as_draws_rvars(example_draws()), + rvar = as_draws_rvars(example_draws())$mu + ) + + weights <- rexp(ndraws(example_draws())) + x_weighted <- lapply(x, weight_draws, weights) + + for (type in names(x)) { + expect_equal(weight_draws(x_weighted[[!!type]], NULL), x[[!!type]]) + } }) # conversion preserves weights -------------------------------------------- @@ -71,7 +126,8 @@ test_that("conversion between formats preserves weights", { array = weight_draws(draws_array(x = 1:10), 1:10), df = weight_draws(draws_df(x = 1:10), 1:10), list = weight_draws(draws_list(x = 1:10), 1:10), - rvars = weight_draws(draws_rvars(x = 1:10), 1:10) + rvars = weight_draws(draws_rvars(x = 1:10), 1:10), + rvar = weight_draws(rvar(x = 1:10), 1:10) ) # chain/iteration/draw columns are placed at the end by conversion functions, @@ -97,3 +153,19 @@ test_that("pareto smoothing smooths weights in weight_draws", { smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE) expect_false(all(weights(weighted) == weights(smoothed))) }) + +# assertions on weights vector ------------------------------------------------ + +test_that("weights must match draws", { + x <- example_draws() + types <- list(as_draws_matrix, as_draws_array, as_draws_df, as_draws_list, as_draws_rvars) + for (type in types) { + expect_error(weight_draws((!!type)(x), 1), "weights must match .* draws") + } +}) + +test_that("weights must be a vector, not array/matrix", { + x <- example_draws() + w <- seq_len(ndraws(x)) + expect_error(weight_draws(x, matrix(w)), "Must be.*vector.*not.*matrix") +}) diff --git a/vignettes/rvar.Rmd b/vignettes/rvar.Rmd index fe61835a..fbf3abba 100755 --- a/vignettes/rvar.Rmd +++ b/vignettes/rvar.Rmd @@ -566,6 +566,35 @@ x This approach is also nice because it generalizes easily to more than two components. 
+## Weights + +Weighted `rvar`s can be created by passing log weights to the `log_weights` +parameter of `rvar()`, by using the `weight_draws()` function (as with `draws` +objects), or by converting a weighted `draws` object to a `draws_rvars` object. +Functions of `rvar`s, such as `mean()`, `sd()`, etc., support weights as +appropriate. + +For example, we can create an `rvar` that is a mixture of draws from +Normal(0,1) and Normal(5,1) distributions: + +```{r rvar_weighted} +x <- rvar(c(rnorm(10000, mean = c(0,5)))) +x +``` + +By default, the mean is about 2.5, as the components with mean 0 and mean 5 +are weighted equally. However, if we weight the component with mean 5 twice +as much, then the summary display will show the appropriate weighted mean: + +```{r weighted_mean} +x <- weight_draws(x, rep(c(1,2), 5000)) +x +``` + +The latest version of [ggdist](https://mjskay.github.io/ggdist/) also supports +weighted `rvar`s, and will calculate histograms, densities, point summaries, and +intervals of `rvar`s correctly, accounting for weights. + ## Applying functions over `rvar`s The `rvar` data type supplies an implementation of `as.list()`, which should give