Skip to content

Commit

Permalink
Merge pull request #326 from n-kall/autothin
Browse files Browse the repository at this point in the history
add automatic thinning of draws
  • Loading branch information
paul-buerkner authored Jan 15, 2024
2 parents 0322b46 + 3dc38b1 commit 98bfcbd
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 16 deletions.
49 changes: 39 additions & 10 deletions R/thin_draws.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
#' Thin `draws` objects
#'
#' Thin [`draws`] objects to reduce their size and autocorrelation in the chains.
#' Thin [`draws`] objects to reduce their size and autocorrelation in
#' the chains.
#'
#' @aliases thin
#' @template args-methods-x
#' @param thin (positive integer) The period for selecting draws.
#' @param thin (positive numeric) The period for selecting draws. Must
#' be between 1 and the number of iterations. If the value is not an
#' integer, the draws will be selected such that the number of draws
#' returned is approximately `round(ndraws(x) / thin)`. Intervals between
#' selected draws will be either `ceiling(thin)` or `floor(thin)`, such
#' that the average interval will be close to the `thin` value. If
#' `NULL`, it will be automatically calculated based on bulk and
#' tail effective sample size as suggested by Säilynoja et
#' al. (2022).
#' @template args-methods-dots
#' @template ref-sailynoja-ecdf-2022
#' @template return-draws
#'
#' @examples
Expand All @@ -16,31 +26,50 @@
#' niterations(x)
#'
#' @export
thin_draws <- function(x, thin, ...) {
thin_draws <- function(x, thin = NULL, ...) {
# NOTE(review): this is a rendered diff without +/- markers; the two
# signature lines above look like the superseded and the replacement
# version of the same line (the latter adds the `thin = NULL` default).
# S3 generic: UseMethod() dispatches to the thin_draws method for `x`.
UseMethod("thin_draws")
}

#' @rdname thin_draws
#' @export
thin_draws.draws <- function(x, thin, ...) {
# NOTE(review): this is a rendered diff without +/- markers, so superseded
# and replacement lines appear back to back throughout this function (e.g.
# the signature above and the `thin = NULL` signature two lines below).
thin <- as_one_integer(thin)
thin_draws.draws <- function(x, thin = NULL, ...) {
# No period supplied: derive one with ess_based_thinning_all_vars() and tell
# the user which (rounded) value was chosen.
if (is.null(thin)) {
thin <- ess_based_thinning_all_vars(x)
message("Automatically thinned by ", round(thin, 1), " based on ESS.")
}

thin <- as_one_numeric(thin)
if (thin == 1L) {
# no thinning requested
return(x)
}
if (thin <= 0L) {
stop_no_call("'thin' must be a positive integer.")
# thin == 1 has already returned above, so this check only fires for
# thin < 1.
if (thin <= 1L) {
stop_no_call("'thin' must be greater than or equal to 1")
}
niterations <- niterations(x)
if (thin > niterations ) {
# NOTE(review): thin == niterations passes this check although the message
# below says "smaller than".
if (thin > niterations) {
stop_no_call("'thin' must be smaller than the total number of iterations.")
}
iteration_ids <- seq(1, niterations, by = thin)
# `thin` may be fractional, so the sequence positions are rounded before
# being used as iteration ids.
iteration_ids <- round(seq(1, niterations, by = thin))
subset_draws(x, iteration = iteration_ids)
}

#' @rdname thin_draws
#' @export
thin_draws.rvar <- function(x, thin, ...) {
thin_draws.rvar <- function(x, thin = NULL, ...) {
# NOTE(review): rendered diff without +/- markers; the two signature lines
# above look like the superseded and the replacement version.
# Wrap the rvar via draws_rvars(x = x), thin that object, and pull the
# element named `x` back out of the result.
thin_draws(draws_rvars(x = x), thin, ...)$x
}

# Thinning period to use for a whole draws object.
#
# Computes one ESS-based thinning period per variable (via
# ess_based_thinning()) and returns the largest of them, but never less
# than 1.
#
# @param x object passed on to summarise_draws().
# @param ... currently unused.
# @return a single numeric thinning period, >= 1.
ess_based_thinning_all_vars <- function(x, ...) {
  # The summary function is passed under the name `thin`, so the per-variable
  # periods end up in the `thin` column of the summary.
  thin_by_variable <- summarise_draws(x, thin = ess_based_thinning)$thin
  # A period is niterations / ESS, which falls below 1 whenever the ESS
  # estimate exceeds the number of iterations. thin_draws.draws() rejects
  # values below 1, so clamp at 1 ("no thinning") instead of letting
  # automatic thinning fail with an error about a `thin` value the user
  # never supplied.
  max(1, thin_by_variable)
}

# ESS-based thinning period for a single variable.
#
# The period is the number of rows of the draws matrix divided by the mean,
# over its columns (chains, per the original comment), of the smaller of
# tail-ESS and bulk-ESS.
#
# @param x draws of one variable; coerced with as.matrix().
# @param ... currently unused.
# @return a single numeric thinning period.
ess_based_thinning <- function(x, ...) {
  draws_mat <- as.matrix(x)

  # smaller of tail and bulk ESS for one column; each estimate is wrapped in
  # SW() (presumably a suppress-warnings helper -- confirm)
  min_ess_of_column <- function(column_draws) {
    tail_ess <- SW(ess_tail(column_draws))
    bulk_ess <- SW(ess_bulk(column_draws))
    min(tail_ess, bulk_ess)
  }

  ess_per_column <- apply(draws_mat, MARGIN = 2, FUN = min_ess_of_column)
  nrow(draws_mat) / mean(ess_per_column)
}
5 changes: 5 additions & 0 deletions man-roxygen/ref-sailynoja-ecdf-2022.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @references
#' Teemu Säilynoja, Paul-Christian Bürkner, and Aki Vehtari (2022).
#' Graphical test for discrete uniformity and its applications in
#' goodness-of-fit evaluation and multiple sample comparison.
#' *Statistics and Computing*. 32, 32. doi:10.1007/s11222-022-10090-6
25 changes: 20 additions & 5 deletions man/thin_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 21 additions & 1 deletion tests/testthat/test-thin_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ test_that("thin_draws works correctly", {
x <- as_draws_array(example_draws())
expect_equal(niterations(thin_draws(x, 5L)), niterations(x) / 5)
expect_equal(x, thin_draws(x, thin = 1L))
expect_error(thin_draws(x, -1), "'thin' must be a positive integer")
expect_error(thin_draws(x, -1), "'thin' must be greater than or equal to 1")
expect_error(thin_draws(x, 1000), "'thin' must be smaller than")
})

Expand All @@ -11,3 +11,23 @@ test_that("thin_draws works on rvars", {

expect_equal(thin_draws(as_draws_rvars(x)$theta, 10L), as_draws_rvars(thin_draws(x, 10L))$theta)
})

test_that("automatic thinning works as expected", {
  draws <- as_draws_array(example_draws())
  mu <- subset_draws(draws, "mu")

  # smaller of tail and bulk ESS, computed separately for chains 1 to 4
  ess_by_chain <- vapply(
    1:4,
    function(chain_id) {
      mu_chain <- subset_draws(mu, chain = chain_id)
      SW(min(ess_tail(mu_chain), ess_bulk(mu_chain)))
    },
    numeric(1)
  )

  # expected period: iterations divided by the mean per-chain ESS
  thin_by <- niterations(mu) / mean(ess_by_chain)

  # thinning without `thin` must match thinning with the hand-computed period
  expect_equal(thin_draws(mu), thin_draws(mu, thin = thin_by))
})

0 comments on commit 98bfcbd

Please sign in to comment.