Skip to content

Commit

Permalink
add tests for partial variable matching and allow draws objs with les…
Browse files Browse the repository at this point in the history
…s draws than procs
  • Loading branch information
SteveBronder committed Apr 4, 2024
1 parent 872405f commit d85f3b5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 36 deletions.
96 changes: 62 additions & 34 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ process_init.default <- function(x, ...) {
#' @param init A type that inherits the `posterior::draws` class.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
Expand All @@ -1054,16 +1054,43 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
} else {
variable_names = colnames(draws)[!grepl("__", colnames(draws))]
}
draws <- posterior::subset_draws(init, variable = variable_names)
draws <- posterior::resample_draws(draws, ndraws = num_procs,
method ="simple_no_replace")
draws <- posterior::as_draws_df(init)
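# All other process_init methods return exactly `num_procs` inits, so do the
# same here. Having fewer draws than `num_procs` can only happen when a raw
# draws object is passed; in that case the available draws are recycled.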
if (nrow(draws) < num_procs) {
idx <- rep(1:nrow(draws), ceiling(num_procs / nrow(draws)))[1:num_procs]
draws <- draws[idx,]
} else if (nrow(draws) > num_procs) {
draws <- posterior::resample_draws(draws, ndraws = num_procs,
method ="simple_no_replace")
}
draws_rvar = posterior::as_draws_rvars(draws)
variable_names <- variable_names[variable_names %in% names(draws_rvar)]
draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names)
# Build one named list of initial values per requested process. Each variable
# is extracted once for the given draw and then checked, so a draw containing
# NA or Inf values stops with an error naming the offending variables.
inits <- lapply(seq_len(num_procs), function(draw_iter) {
  init_i <- lapply(variable_names, function(var_name) {
    drop(posterior::draws_of(drop(
      posterior::subset_draws(draws_rvar[[var_name]], draw = draw_iter))))
  })
  names(init_i) <- variable_names
  # TRUE for every variable whose value for this draw has an NA or Inf entry.
  has_bad_value <- vapply(init_i, function(x) {
    any(is.infinite(x)) || any(is.na(x))
  }, logical(1))
  if (any(has_bad_value)) {
    bad_names <- variable_names[has_bad_value]
    err_msg <- paste0(paste(bad_names, collapse = ", "), " contains NA or Inf values!")
    # Choose the prefix from the number of offending variables, not from the
    # total number of variables that were checked.
    if (length(bad_names) > 1) {
      err_msg <- paste0("Variables: ", err_msg)
    } else {
      err_msg <- paste0("Variable: ", err_msg)
    }
    stop(err_msg)
  }
  init_i
})
Expand All @@ -1075,7 +1102,7 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
#' @param init List of init lists.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
Expand Down Expand Up @@ -1151,7 +1178,7 @@ process_init.list <- function(init, num_procs, model_variables = NULL,
#' @param init Function generating a single list of initial values.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @return A character vector of file paths.
process_init.function <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
Expand All @@ -1173,24 +1200,30 @@ process_init.function <- function(init, num_procs, model_variables = NULL,
process_init(init_list, num_procs, model_variables)
}

#' Check that a fitted model can supply initial values
#'
#' Signals an error when every return code reported by `init` equals 1, or
#' when `model_variables` is supplied and none of its parameter names appear
#' among the Stan variables recorded in the metadata of `init`.
#' @noRd
#' @param init A fitted model object providing `return_codes()` and
#'   `metadata()` methods.
#' @param model_variables A list with a named `parameters` element, or `NULL`
#'   to skip the parameter-name check.
#' @return `NULL`, invisibly. Called for the errors it may raise.
validate_fit_init <- function(init, model_variables) {
  every_code_is_one <- all(init$return_codes() == 1)
  if (every_code_is_one) {
    stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
  }
  # Without model variables there is nothing to compare names against.
  if (is.null(model_variables)) {
    return(invisible(NULL))
  }
  wanted_names <- names(model_variables$parameters)
  available_names <- init$metadata()$stan_variables
  if (!any(wanted_names %in% available_names)) {
    stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
  }
  invisible(NULL)
}

#' Write initial values to files if provided as a `CmdStanMCMC` class
#' @noRd
#' @param init A `CmdStanMCMC` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanMCMC <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
# Convert from data.table to data.frame
if (all(init$return_codes() == 1)) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
validate_fit_init(init, model_variables)
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
Expand All @@ -1207,20 +1240,16 @@ process_init.CmdStanMCMC <- function(init, num_procs, model_variables = NULL,
#' @param init A set of draws with `lp__` and `lp_approx__` columns.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
#' @importFrom stats aggregate
process_init_approx <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
validate_fit_init(init, model_variables)
# Convert from data.table to data.frame
if (init$return_codes() == 1) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[3:(length(colnames(draws_df)) - 3)])
Expand All @@ -1238,13 +1267,17 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
if (unique_draws < (0.95 * nrow(draws_df))) {
temp_df = stats::aggregate(.draw ~ lw, data = draws_df, FUN = min)
draws_df = posterior::as_draws_df(merge(temp_df, draws_df, by = 'lw'))
draws_df$pareto_weight = exp(draws_df$lw - max(draws_df$lw))
draws_df$weight = exp(draws_df$lw - max(draws_df$lw))
} else {
draws_df$pareto_weight = posterior::pareto_smooth(
exp(draws_df$lw - max(draws_df$lw)), tail = "right", return_k=FALSE)
if (inherits(init, "CmdStanPathfinder") && init$metadata()$psis_resample) {
draws_df$weight = rep(1.0, nrow(draws_df))
} else {
draws_df$weight = posterior::pareto_smooth(
exp(draws_df$lw - max(draws_df$lw)), tail = "right", return_k=FALSE)
}
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
weights = draws_df$pareto_weight, method = "simple_no_replace")
weights = draws_df$weight, method = "simple_no_replace")
init_draws_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables, warn_partial)
return(init_draws_lst)
Expand All @@ -1256,7 +1289,7 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
#' @param init A `CmdStanPathfinder` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
Expand All @@ -1271,7 +1304,7 @@ process_init.CmdStanPathfinder <- function(init, num_procs, model_variables = NU
#' @param init A `CmdStanVB` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
Expand All @@ -1286,7 +1319,7 @@ process_init.CmdStanVB <- function(init, num_procs, model_variables = NULL,
#' @param init A `CmdStanLaplace` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
Expand All @@ -1302,25 +1335,20 @@ process_init.CmdStanLaplace <- function(init, num_procs, model_variables = NULL,
#' @param init A `CmdStanMLE` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' number of dimensions. Typically the output of `model$variables()$parameters`.
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanMLE <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
# Convert from data.table to data.frame
if (init$return_codes() == 1) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
validate_fit_init(init, model_variables)
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
method = "simple")
init_draws_df = draws_df[rep(1, num_procs),]
init_draws_lst_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables, warn_partial)
return(init_draws_lst_lst)
Expand Down
2 changes: 1 addition & 1 deletion man/cmdstanr-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/resources/stan/logistic_simple.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Intercept-only logistic model: a single parameter `alpha` is the log-odds
// shared by all N binary outcomes.
data {
int<lower=0> N; // number of observations
array[N] int<lower=0, upper=1> y; // binary outcomes, one per observation
}
parameters {
real alpha; // common log-odds of y == 1
}
model {
// Standard normal prior on alpha, added to the log density explicitly.
target += normal_lpdf(alpha | 0, 1);
// Bernoulli likelihood with alpha supplied on the logit scale.
target += bernoulli_logit_lpmf(y | alpha);
}
32 changes: 31 additions & 1 deletion tests/testthat/test-fit-init.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set_cmdstan_path()
mod_params <- testing_model("parameter_types")
mod_schools <- testing_model("schools")
mod_logistic <- testing_model("logistic")
mod_logistic_simple <- testing_model("logistic_simple")
data_list_schools <- testing_data("schools")
data_list_logistic <- testing_data("logistic")
test_inits <- function(mod, fit_init, data_list = NULL) {
Expand Down Expand Up @@ -32,14 +33,23 @@ test_inits <- function(mod, fit_init, data_list = NULL) {
}

# A single-chain fit is used as `init` for a four-chain fit of the same
# model; both fits are then passed through `test_inits()` and must not error.
test_that("Sample method works as init", {
  utils::capture.output(fit_sample_init <- mod_params$sample(chains = 1,
    iter_warmup = 100, iter_sampling = 100, refresh = 0, seed = 1234))
  utils::capture.output(fit_sample_multi_init <- mod_params$sample(chains = 4, init = fit_sample_init,
    iter_warmup = 100, iter_sampling = 100, refresh = 0, seed = 1234))
  expect_no_error(test_inits(mod_params, fit_sample_init))
  expect_no_error(test_inits(mod_params, fit_sample_multi_init))
})

# Fit the `logistic_simple` model, then hand that fit to `test_inits()` as
# the init for `mod_logistic`. The test expects a message to be emitted
# (presumably about partially specified inits; confirm against `test_inits()`).
test_that("Subsets of parameters are allowed", {
  utils::capture.output(
    simple_fit <- mod_logistic_simple$sample(
      chains = 1,
      data = data_list_logistic,
      iter_warmup = 100,
      iter_sampling = 100,
      refresh = 0,
      seed = 1234
    )
  )
  expect_message(
    test_inits(mod_logistic, simple_fit, data_list_logistic)
  )
})


test_that("Pathfinder method works as init", {
utils::capture.output(fit_path_init <- mod_params$pathfinder(seed=1234,
refresh = 0, num_paths = 4))
Expand Down Expand Up @@ -67,3 +77,23 @@ test_that("Optimization method works as init", {
data = data_list_logistic, seed=1234, refresh = 0))
expect_no_error(test_inits(mod_logistic, fit_ml_init, data_list_logistic))
})


# Corrupt entries of a Laplace fit's draws with NA / Inf and check that using
# those draws as `init` fails with an error naming the affected variables.
test_that("Draws Object with NA or Inf throws error", {
  utils::capture.output(
    laplace_fit <- mod_logistic$laplace(
      data = data_list_logistic, seed = 1234, refresh = 0
    )
  )

  # First pass: NA in column 3, then additionally in column 4.
  na_draws <- laplace_fit$draws()
  na_draws[1, 3] <- NA
  expect_error(
    mod_logistic$laplace(
      data = data_list_logistic, seed = 1234, refresh = 0,
      init = na_draws[1, ]
    ),
    "alpha contains NA or Inf values!"
  )
  na_draws[1, 4] <- NA
  expect_error(
    mod_logistic$sample(
      data = data_list_logistic, seed = 1234, refresh = 0,
      init = na_draws[1:4, ]
    ),
    "alpha, beta contains NA or Inf values!"
  )

  # Second pass on fresh draws: Inf in column 3, then NA in column 4.
  inf_draws <- laplace_fit$draws()
  inf_draws[1, 3] <- Inf
  expect_error(
    mod_logistic$sample(
      data = data_list_logistic, seed = 1234, refresh = 0,
      init = inf_draws[1:4, ]
    ),
    "alpha contains NA or Inf values!"
  )
  inf_draws[1, 4] <- NA
  expect_error(
    mod_logistic$sample(
      data = data_list_logistic, seed = 1234, refresh = 0,
      init = inf_draws[1:4, ]
    ),
    "alpha, beta contains NA or Inf values!"
  )
})

0 comments on commit d85f3b5

Please sign in to comment.