Implement checkpointing for bmm models #130

Draft: wants to merge 6 commits into base: develop
Changes from 4 commits
4 changes: 3 additions & 1 deletion DESCRIPTION
@@ -29,7 +29,9 @@ Suggests:
cowplot,
stringr,
remotes,
waldo
waldo,
xfun,
chkptstanr
Config/testthat/edition: 3
Imports:
magrittr,
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

### New features

* add an experimental option to checkpoint the sampling of bmm models via the `checkpoints` argument in `fit_model()`. It uses the `chkptstanr` package as a backend to save the sampling results every `checkpoints` iterations, so that a long sampling run can be resumed from the last checkpoint after a crash or other interruption (#129). The option works only with the `cmdstanr` backend and requires a forked version of `chkptstanr` from GitHub that implements a number of bugfixes; install it with `remotes::install_github("venpopov/chkptstanr")`. See `?fit_model` for how to use the `checkpoints` argument, and the `chkptstanr` package documentation for the motivation and benefits of using checkpoints.
* add a check for the sdmSimple model of whether the data is sorted by predictors; sorted data leads to much faster sampling. The user can control the default behavior with the `sort_data` argument (#72)
* the mixture3p and imm models now require that when set size is used as a predictor, the intercept must be suppressed. This is because set size 1 otherwise causes problems - there can be no contribution of non_target responses when there is set size 1, and it is not meaningful to estimate an intercept for parameters that involve non_target responses (#96).
* add postprocessing methods for sdmSimple to allow for pp_check(), conditional_effects and bridgesampling usage with the model (#30)
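
To make the checkpoint-and-resume idea in the first bullet concrete, here is a toy sketch in base R. It is not how `chkptstanr` is implemented; the function `run_with_checkpoints()`, the `state.rds` file name, and the use of `rnorm()` as a stand-in for sampling are all assumptions made for illustration.

```r
# Toy illustration only: save progress every `iter_per_chkpt` iterations and
# resume from the saved state if one exists. All names here are hypothetical.
run_with_checkpoints <- function(total_iter, iter_per_chkpt, path, stop_after = NULL) {
  state_file <- file.path(path, "state.rds")
  # resume from the last checkpoint if there is one, otherwise start fresh
  state <- if (file.exists(state_file)) {
    readRDS(state_file)
  } else {
    list(done = 0, draws = numeric(0))
  }
  chkpts_run <- 0
  while (state$done < total_iter) {
    # simulate an interruption after `stop_after` checkpoints
    if (!is.null(stop_after) && chkpts_run >= stop_after) break
    n <- min(iter_per_chkpt, total_iter - state$done)
    state$draws <- c(state$draws, rnorm(n))  # stand-in for real sampling
    state$done <- state$done + n
    saveRDS(state, state_file)               # write the checkpoint
    chkpts_run <- chkpts_run + 1
  }
  state
}

dir <- tempfile("chkpt_demo")
dir.create(dir)
first <- run_with_checkpoints(100, 25, dir, stop_after = 2)  # stops at 50 iterations
second <- run_with_checkpoints(100, 25, dir)                 # resumes and finishes at 100
unlink(dir, recursive = TRUE)
```

The second call does not redo the first 50 iterations; it reads the saved state and only runs the remaining ones, which is the behavior the changelog entry describes for interrupted sampling runs.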
1 change: 0 additions & 1 deletion R/bmm_model_sdmSimple.R
@@ -128,7 +128,6 @@ configure_model.sdmSimple <- function(model, data, formula) {
formula <- bmf2bf(model, bmm_formula)

# construct the default prior
# TODO: for now it just fixes mu to 0, I have to add proper priors
prior <- fixed_pars_priors(model)
if (getOption("bmm.default_priors", TRUE)) {
prior <- prior +
35 changes: 32 additions & 3 deletions R/fit_model.R
@@ -36,13 +36,36 @@
#' printed. Set refresh = 0 to turn this off as well. If using backend =
#' "rstan" you can also set open_progress = FALSE to prevent opening
#' additional progress bars.
#' @param ... Further arguments passed to [brms::brm()] or Stan. See the
#' @param checkpoints Experimental. The number of iterations between checkpoints. The
#' current state of the sampler is saved every `checkpoints` iterations. This option uses the [chkptstanr][chkptstanr::chkptstanr-package] package to
#' allow interrupted sampling to be resumed. Disabled by default (NULL). Enabling this
#' option requires the [chkptstanr][chkptstanr::chkptstanr-package] package to be installed (see details).
#' @param checkpoints_folder If checkpoints is not NULL, this argument specifies the
#' directory where the checkpoints will be saved. Must be a relative path.
#' @param checkpoints_path If NULL (default), the checkpoints_folder will be created
#' in the current working directory. Alternatively, you can specify a path to
#' an existing folder where the checkpoints_folder will be created.
#' @param ... Further arguments passed to [brms::brm()], Stan or chkptstanr. See the
#' description of [brms::brm()] for more details
#'
#' @details `r a= supported_models(); a`
#'
#' Type `help(package=bmm)` for a full list of available help topics.
#'
#' ## Using checkpoints
#'
#' The `checkpoints` argument allows you to save the current state of the
#' sampler every `checkpoints` iterations. If sampling is interrupted, you can
#' resume it later from the last checkpoint. This feature requires
#' the chkptstanr package to be installed and `backend = "cmdstanr"`.
#' The current CRAN version of chkptstanr has a bug that prevents it from
#' working. Until the issue is fixed, you can install a working forked version
#' of chkptstanr with:
#'
#' ``` r
#' remotes::install_github("venpopov/chkptstanr")
#' ```
#'
#' @returns An object of class brmsfit which contains the posterior draws along
#' with other useful information about the model. Use methods(class =
#' "brmsfit") for an overview on available methods.
@@ -72,11 +95,14 @@
#' parallel=T,
#' iter=500,
#' backend='cmdstanr')
#'
#' # TODO: add checkpoint example
#' }
#'
fit_model <- function(formula, data, model, parallel = FALSE, chains = 4,
prior = NULL, sort_data = getOption('bmm.sort_data', NULL),
silent = getOption('bmm.silent', 1), ...) {
silent = getOption('bmm.silent', 1), checkpoints = NULL,
checkpoints_folder = NULL, checkpoints_path = NULL, ...) {
# warning for using old version
dots <- list(...)
if ("model_type" %in% names(dots)) {
@@ -105,7 +131,10 @@ fit_model <- function(formula, data, model, parallel = FALSE, chains = 4,
# estimate the model
dots <- list(...)
fit_args <- combine_args(nlist(config_args, opts, dots))
fit <- call_brm(fit_args)
fit <- run_model(fit_args,
checkpoints = checkpoints,
checkpoints_folder = checkpoints_folder,
checkpoints_path = checkpoints_path)

# model postprocessing
postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula,
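
The roxygen examples above still carry a TODO for a checkpoint example. A possible sketch, pieced together from the arguments used in this PR's test file, is shown here. The wrapper `checkpoint_example()` and its default folder name are hypothetical, the `bmm::` objects are the ones the test refers to, and the meaning of `stop_after = 2` is taken from the test's own message ("Running for 2 checkpoints then stopping"). Nothing is fitted until the function is called, and calling it requires bmm, cmdstanr and the forked chkptstanr.

```r
# Hypothetical wrapper around the calls used in the PR's test; defining it
# runs nothing. `folder` must be a relative path.
checkpoint_example <- function(folder = "bmm_checkpoints") {
  dat <- bmm::OberauerLin_2017
  dat <- dat[dat$set_size %in% c(1, 2, 3, 4) & dat$ID %in% 1:10, ]
  formula <- bmm::bmf(c ~ set_size, kappa ~ 1)
  model <- bmm::sdmSimple("dev_rad")

  # first call: stop early to mimic an interrupted run
  try(
    bmm::fit_model(formula, dat, model,
                   backend = "cmdstanr",
                   iter = 100,
                   checkpoints = 25,
                   stop_after = 2,
                   checkpoints_folder = folder),
    silent = TRUE
  )

  # second call: same folder, so sampling picks up from the last checkpoint
  bmm::fit_model(formula, dat, model,
                 backend = "cmdstanr",
                 iter = 100,
                 checkpoints = 25,
                 checkpoints_folder = folder)
}
```

The key point is that both calls pass the same `checkpoints_folder`, which is what lets the second call find the state left behind by the first.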
60 changes: 58 additions & 2 deletions R/utils.R
@@ -129,9 +129,65 @@ glue_lf <- function(...) {
#' not used directly, but called by fit_model(). If fit_model() is run with
#' backend="mock", then we can perform tests on the fit_args to check if the
#' model configuration is correct. Avoids compiling and running the model
#' in that case. If checkpoints is not NULL, the model is fit with
#' chkptstanr::chkpt_brms() instead of brms::brm()
#' @noRd
call_brm <- function(fit_args) {
fit <- brms::do_call(brms::brm, fit_args)
run_model <- function(fit_args, checkpoints, checkpoints_folder, checkpoints_path) {
  # without checkpoints, fit the model directly with brms
  if (is.null(checkpoints)) {
    fit <- brms::do_call(brms::brm, fit_args)
    return(fit)
  }

  if (is.null(checkpoints_folder)) {
    stop2("You must provide a folder name to save the checkpoints")
  }

  # needed because of the setup in chkptstanr::create_folder(). Eventually can remove
  # if I rework their function
  if (xfun::is_abs_path(checkpoints_folder)) {
    stop2("The checkpoints_folder argument must be a relative path.\n",
          "You can provide a base path in which to create the folder",
          " with the checkpoints_path argument, or leave it as NULL to use",
          " the current working directory.")
  }

  # stop2() concatenates its arguments without a separator, so each piece
  # needs its own trailing space
  if (!requireNamespace("chkptstanr", quietly = TRUE)) {
    stop2(
      "\nPackage \"chkptstanr\" must be installed to use this function.\n",
      "The current CRAN version of chkptstanr has a bug that prevents it from ",
      "working. Until the issue is fixed, you can install a working forked version ",
      "of chkptstanr with:\n\n",
      "remotes::install_github(\"venpopov/chkptstanr\")"
    )
  }

  # identical() avoids an error in if() when backend is NULL
  if (identical(fit_args$backend, "rstan")) {
    stop2("Checkpoints are not supported for rstan. Use backend='cmdstanr' instead.")
  }

  # translate brms-style iter/warmup into iter_warmup/iter_sampling;
  # the default warmup follows brms (floor(iter / 2))
  if (!is.null(fit_args$iter)) {
    fit_args$iter_warmup <- if (is.null(fit_args$warmup)) floor(fit_args$iter / 2) else fit_args$warmup
    fit_args$iter_sampling <- fit_args$iter - fit_args$iter_warmup
    fit_args$iter <- NULL
    fit_args$warmup <- NULL
  }

  fit_args$iter_per_chkpt <- checkpoints
  fit_args$path <- file_path2(checkpoints_path, checkpoints_folder)
  if (!dir.exists(fit_args$path)) {
    # TODO: this check should really be implemented in chkptstanr::create_folder
    fit_args$path <- chkptstanr::create_folder(checkpoints_folder, path = checkpoints_path)
  }
  attr(fit_args$path, "info") <- "chkpt_brms folder"

  fit <- brms::do_call(chkptstanr::chkpt_brms, fit_args)
  fit
}
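
The translation of brms-style `iter`/`warmup` into `iter_warmup`/`iter_sampling` inside `run_model()` can be sketched in isolation as follows. The helper `split_iterations()` is hypothetical and not part of the PR, and the use of `floor()` for odd `iter` values is an assumption that mirrors the brms default of `warmup = floor(iter / 2)`.

```r
# Hypothetical helper: split a total iteration count into warmup and sampling.
# If `warmup` is NULL, half of `iter` (rounded down) is used as warmup.
split_iterations <- function(iter, warmup = NULL) {
  iter_warmup <- if (is.null(warmup)) floor(iter / 2) else warmup
  list(iter_warmup = iter_warmup, iter_sampling = iter - iter_warmup)
}

str(split_iterations(100))                # 50 warmup, 50 sampling
str(split_iterations(500, warmup = 100))  # 100 warmup, 400 sampling
```

Using `if ... else` rather than `ifelse()` is a small design choice: the condition is a single logical value, so the vectorized form is not needed.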

# wrapper around file.path() that ignores NULL arguments
file_path2 <- function(...) {
  dots <- list(...)
  # vapply() always returns a logical vector, even when no arguments are given
  dots <- dots[!vapply(dots, is.null, logical(1))]
  do.call(file.path, dots)
}
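
A quick sketch of why a NULL-ignoring wrapper is useful here: with `checkpoints_path = NULL`, the folder name should pass through unchanged. The function is repeated so the snippet runs on its own (written with `vapply()`, an assumption about the preferred form), and the folder names are made up for illustration.

```r
# Standalone copy of the helper for illustration
file_path2 <- function(...) {
  dots <- list(...)
  dots <- dots[!vapply(dots, is.null, logical(1))]
  do.call(file.path, dots)
}

file_path2(NULL, "checkpoints")         # "checkpoints"
file_path2("results", "checkpoints")    # "results/checkpoints"
```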


32 changes: 31 additions & 1 deletion man/fit_model.Rd

Generated file; diff not rendered by default.

33 changes: 33 additions & 0 deletions tests/internal/test_checkpointing.R
@@ -0,0 +1,33 @@
test_that('fit_model works with checkpointing', {
  folder <- 'local/checkpoints2'
  # clean up the checkpoint folder when the test finishes
  on.exit(unlink(folder, recursive = TRUE, force = TRUE))

  data1 <- dplyr::filter(OberauerLin_2017, set_size %in% c(1,2,3,4), ID %in% 1:10)
  formula <- bmf(c ~ set_size, kappa ~ 1)
  model <- sdmSimple('dev_rad')

  cat("\n\nRunning for 2 checkpoints then stopping\n\n")

  fit <- try(fit_model(formula, data1, model,
                       parallel = TRUE,
                       backend = 'cmdstanr',
                       sort_data = TRUE,
                       iter = 100,
                       checkpoints = 25,
                       stop_after = 2,
                       checkpoints_folder = folder),
             silent = TRUE)

  cat("\n\nTrying to pick up where we stopped\n\n")

  fit <- try(fit_model(formula, data1, model,
                       parallel = TRUE,
                       backend = 'cmdstanr',
                       sort_data = TRUE,
                       iter = 100,
                       checkpoints = 25,
                       checkpoints_folder = folder),
             silent = TRUE)

  # the resumed run must complete without error
  expect_false(inherits(fit, 'try-error'))
})
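
The test relies on `try()` returning an object of class `"try-error"` when the wrapped call fails. A minimal base R sketch of that pattern, independent of bmm (the variable names are made up for illustration):

```r
# try() returns the value on success and a "try-error" object on failure
ok  <- try(sqrt(4), silent = TRUE)
bad <- try(stop("sampling was interrupted"), silent = TRUE)

inherits(ok, "try-error")   # FALSE
inherits(bad, "try-error")  # TRUE
```

Because the first `fit_model()` call in the test is expected to stop early, wrapping it in `try()` keeps the test going so that the second call can exercise the resume path.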