diff --git a/DESCRIPTION b/DESCRIPTION index 3d55cb48..7ce3e65e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org BugReports: https://github.com/stan-dev/cmdstanr/issues Encoding: UTF-8 LazyData: true -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 Roxygen: list(markdown = TRUE, r6 = FALSE) SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan) Depends: diff --git a/R/args.R b/R/args.R index 7b4a8686..0bfdf474 100644 --- a/R/args.R +++ b/R/args.R @@ -77,11 +77,12 @@ CmdStanArgs <- R6::R6Class( } self$output_dir <- repair_path(self$output_dir) self$output_basename <- output_basename - if (is.function(init)) { - init <- process_init_function(init, length(self$proc_ids), model_variables) - } else if (is.list(init) && !is.data.frame(init)) { - init <- process_init_list(init, length(self$proc_ids), model_variables) + if (inherits(self$method_args, "PathfinderArgs")) { + num_inits <- self$method_args$num_paths + } else { + num_inits <- length(self$proc_ids) } + init <- process_init(init, num_inits, model_variables) self$init <- init self$opencl_ids <- opencl_ids self$num_threads = NULL @@ -691,7 +692,12 @@ validate_cmdstan_args <- function(self) { assert_file_exists(self$data_file, access = "r") } num_procs <- length(self$proc_ids) - validate_init(self$init, num_procs) + if (inherits(self$method_args, "PathfinderArgs")) { + num_inits <- self$method_args$num_paths + } else { + num_inits <- length(self$proc_ids) + } + validate_init(self$init, num_inits) validate_seed(self$seed, num_procs) if (!is.null(self$opencl_ids)) { if (cmdstan_version() < "2.26") { @@ -1018,17 +1024,63 @@ validate_exe_file <- function(exe_file) { invisible(TRUE) } + +#' Generic for processing inits +#' @noRd +process_init <- function(...) { + UseMethod("process_init") +} + +#' Default method +#' @noRd +process_init.default <- function(x, ...) { + return(x) +} + +#' Write initial values to files if provided as posterior `draws` object +#' @noRd +#' @param init A type that inherits the `posterior::draws` class. +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.draws <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + if (!is.null(model_variables)) { + variable_names = names(model_variables$parameters) + } else { + variable_names = colnames(draws)[!grepl("__", colnames(draws))] + } + draws <- posterior::subset_draws(init, variable = variable_names) + draws <- posterior::resample_draws(draws, ndraws = num_procs, + method ="simple_no_replace") + draws_rvar = posterior::as_draws_rvars(draws) + inits = lapply(1:num_procs, \(draw_iter) { + init_i = lapply(variable_names, \(var_name) { + x = drop(posterior::draws_of(drop( + posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))) + return(x) + }) + names(init_i) = variable_names + return(init_i) + }) + return(process_init(inits, num_procs, model_variables, warn_partial)) +} + #' Write initial values to files if provided as list of lists #' @noRd #' @param init List of init lists. -#' @param num_procs Number of CmdStan processes. +#' @param num_procs Number of inits needed. #' @param model_variables A list of all parameters with their types and #' number of dimensions. Typically the output of model$variables(). #' @param warn_partial Should a warning be thrown if inits are only specified #' for a subset of parameters? Can be controlled by global option #' `cmdstanr_warn_inits`. #' @return A character vector of file paths. -process_init_list <- function(init, num_procs, model_variables = NULL, +process_init.list <- function(init, num_procs, model_variables = NULL, warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { if (!all(sapply(init, function(x) is.list(x) && !is.data.frame(x)))) { stop("If 'init' is a list it must be a list of lists.", call. = FALSE) @@ -1083,10 +1135,11 @@ process_init_list <- function(init, num_procs, model_variables = NULL, } init_paths <- tempfile( - pattern = paste0("init-", seq_along(init), "-"), + pattern = "init-", tmpdir = cmdstan_tempdir(), - fileext = ".json" + fileext = "" ) + init_paths <- paste0(init_paths, "_", seq_along(init), ".json") for (i in seq_along(init)) { write_stan_json(init[[i]], init_paths[i]) } @@ -1096,11 +1149,12 @@ process_init_list <- function(init, num_procs, model_variables = NULL, #' Write initial values to files if provided as function #' @noRd #' @param init Function generating a single list of initial values. -#' @param num_procs Number of CmdStan processes. +#' @param num_procs Number of inits needed. #' @param model_variables A list of all parameters with their types and #' number of dimensions. Typically the output of model$variables(). #' @return A character vector of file paths. -process_init_function <- function(init, num_procs, model_variables = NULL) { +process_init.function <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { args <- formals(init) if (is.null(args)) { fn_test <- init() @@ -1116,7 +1170,159 @@ process_init_function <- function(init, num_procs, model_variables = NULL) { if (!is.list(fn_test) || is.data.frame(fn_test)) { stop("If 'init' is a function it must return a single list.") } - process_init_list(init_list, num_procs, model_variables) + process_init(init_list, num_procs, model_variables) +} + +#' Write initial values to files if provided as a `CmdStanMCMC` class +#' @noRd +#' @param init A `CmdStanMCMC` class +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.CmdStanMCMC <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + # Convert from data.table to data.frame + if (all(init$return_codes() == 1)) { + stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.") + } else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) { + stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.") + } + draws_df = init$draws(format = "df") + if (is.null(model_variables)) { + model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)]) + } + init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs, + method = "simple_no_replace") + init_draws_lst = process_init(init_draws_df, + num_procs = num_procs, model_variables = model_variables) + return(init_draws_lst) +} + +#' Performs PSIS resampling on the draws from an approxmation method for inits. +#' @noRd +#' @param init A set of draws with `lp__` and `lp_approx__` columns. +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init_approx <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + # Convert from data.table to data.frame + if (init$return_codes() == 1) { + stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.") + } else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) { + stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.") + } + draws_df = init$draws(format = "df") + if (is.null(model_variables)) { + model_variables = list(parameters = colnames(draws_df)[3:(length(colnames(draws_df)) - 3)]) + } + draws_df$lw = draws_df$lp__ - draws_df$lp_approx__ + # Calculate unique draws based on 'lw' using base R functions + unique_draws = length(unique(draws_df$lw)) + if (num_procs > unique_draws) { + if (inherits(init, "CmdStanPathfinder")) { + stop(paste0("Not enough distinct draws (", num_procs, ") in pathfinder fit to create inits. Try running Pathfinder with psis_resample=FALSE")) + } else { + stop(paste0("Not enough distinct draws (", num_procs, ") to create inits.")) + } + } + if (unique_draws < (0.95 * nrow(draws_df))) { + temp_df = aggregate(.draw ~ lw, data = draws_df, FUN = min) + draws_df = posterior::as_draws_df(merge(temp_df, draws_df, by = 'lw')) + draws_df$pareto_weight = exp(draws_df$lw - max(draws_df$lw)) + } else { + draws_df$pareto_weight = posterior::pareto_smooth( + exp(draws_df$lw - max(draws_df$lw)), tail = "right")[["x"]] + } + init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs, + weights = draws_df$pareto_weight, method = "simple_no_replace") + init_draws_lst = process_init(init_draws_df, + num_procs = num_procs, model_variables = model_variables, warn_partial) + return(init_draws_lst) +} + + +#' Write initial values to files if provided as a `CmdStanPathfinder` class +#' @noRd +#' @param init A `CmdStanPathfinder` class +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.CmdStanPathfinder <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + process_init_approx(init, num_procs, model_variables, warn_partial) +} + +#' Write initial values to files if provided as a `CmdStanVB` class +#' @noRd +#' @param init A `CmdStanVB` class +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.CmdStanVB <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + process_init_approx(init, num_procs, model_variables, warn_partial) +} + +#' Write initial values to files if provided as a `CmdStanLaplace` class +#' @noRd +#' @param init A `CmdStanLaplace` class +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.CmdStanLaplace <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + process_init_approx(init, num_procs, model_variables, warn_partial) +} + + +#' Write initial values to files if provided as a `CmdStanMLE` class +#' @noRd +#' @param init A `CmdStanMLE` class +#' @param num_procs Number of inits requested +#' @param model_variables A list of all parameters with their types and +#' number of dimensions. Typically the output of model$variables(). +#' @param warn_partial Should a warning be thrown if inits are only specified +#' for a subset of parameters? Can be controlled by global option +#' `cmdstanr_warn_inits`. +#' @return A character vector of file paths. +process_init.CmdStanMLE <- function(init, num_procs, model_variables = NULL, + warn_partial = getOption("cmdstanr_warn_inits", TRUE)) { + # Convert from data.table to data.frame + if (init$return_codes() == 1) { + stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.") + } else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) { + stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.") + } + draws_df = init$draws(format = "df") + if (is.null(model_variables)) { + model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)]) + } + init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs, + method = "simple") + init_draws_lst_lst = process_init(init_draws_df, + num_procs = num_procs, model_variables = model_variables, warn_partial) + return(init_draws_lst_lst) } #' Validate initial values diff --git a/R/cmdstanr-package.R b/R/cmdstanr-package.R index 305241bf..d1f34bf1 100644 --- a/R/cmdstanr-package.R +++ b/R/cmdstanr-package.R @@ -30,6 +30,6 @@ #' @inherit cmdstan_model examples #' @import R6 #' -NULL +"_PACKAGE" if (getRversion() >= "2.15.1") utils::globalVariables(c("self", "private", "super")) diff --git a/R/fit.R b/R/fit.R index 99feca4c..68215608 100644 --- a/R/fit.R +++ b/R/fit.R @@ -518,11 +518,11 @@ unconstrain_variables <- function(variables) { " not provided!", call. = FALSE) } - # Remove zero-length parameters from model_variables, otherwise process_init_list + # Remove zero-length parameters from model_variables, otherwise process_init # warns about missing inputs model_variables$parameters <- model_variables$parameters[nonzero_length_params] - stan_pars <- process_init_list(list(variables), num_procs = 1, model_variables) + stan_pars <- process_init(list(variables), num_procs = 1, model_variables) private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars) } CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables) @@ -594,7 +594,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL, # but not in metadata()$variables nonzero_length_params <- names(model_variables$parameters) %in% model_par_names - # Remove zero-length parameters from model_variables, otherwise process_init_list + # Remove zero-length parameters from model_variables, otherwise process_init # warns about missing inputs pars <- names(model_variables$parameters[nonzero_length_params]) diff --git a/man/model-method-check_syntax.Rd b/man/model-method-check_syntax.Rd index a646a5e1..68366fb5 100644 --- a/man/model-method-check_syntax.Rd +++ b/man/model-method-check_syntax.Rd @@ -86,8 +86,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-compile.Rd b/man/model-method-compile.Rd index d295eedc..7bfa47d7 100644 --- a/man/model-method-compile.Rd +++ b/man/model-method-compile.Rd @@ -157,8 +157,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-diagnose.Rd b/man/model-method-diagnose.Rd index 9a0acd31..99043501 100644 --- a/man/model-method-diagnose.Rd +++ b/man/model-method-diagnose.Rd @@ -129,8 +129,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-expose_functions.Rd b/man/model-method-expose_functions.Rd index a62f7bb8..b7d42231 100644 --- a/man/model-method-expose_functions.Rd +++ b/man/model-method-expose_functions.Rd @@ -77,8 +77,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-format.Rd b/man/model-method-format.Rd index 2aa34f18..d24010a4 100644 --- a/man/model-method-format.Rd +++ b/man/model-method-format.Rd @@ -106,8 +106,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-generate-quantities.Rd b/man/model-method-generate-quantities.Rd index bf25602e..23acba19 100644 --- a/man/model-method-generate-quantities.Rd +++ b/man/model-method-generate-quantities.Rd @@ -178,8 +178,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-laplace.Rd b/man/model-method-laplace.Rd index b033fbe3..253d67f5 100644 --- a/man/model-method-laplace.Rd +++ b/man/model-method-laplace.Rd @@ -214,8 +214,8 @@ Other CmdStanModel methods: \code{\link{model-method-generate-quantities}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-optimize.Rd b/man/model-method-optimize.Rd index dcf77444..b9b53454 100644 --- a/man/model-method-optimize.Rd +++ b/man/model-method-optimize.Rd @@ -332,8 +332,8 @@ Other CmdStanModel methods: \code{\link{model-method-generate-quantities}}, \code{\link{model-method-laplace}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-pathfinder.Rd b/man/model-method-pathfinder.Rd index 85fc9236..41504358 100644 --- a/man/model-method-pathfinder.Rd +++ b/man/model-method-pathfinder.Rd @@ -357,8 +357,8 @@ Other CmdStanModel methods: \code{\link{model-method-generate-quantities}}, \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, \code{\link{model-method-variational}} } diff --git a/man/model-method-variables.Rd b/man/model-method-variables.Rd index dc80ed9a..87e9d73e 100644 --- a/man/model-method-variables.Rd +++ b/man/model-method-variables.Rd @@ -46,8 +46,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variational}} } \concept{CmdStanModel methods} diff --git a/man/model-method-variational.Rd b/man/model-method-variational.Rd index 3678f11e..1b2d9a74 100644 --- a/man/model-method-variational.Rd +++ b/man/model-method-variational.Rd @@ -333,8 +333,8 @@ Other CmdStanModel methods: \code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-pathfinder}}, -\code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, +\code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}} } \concept{CmdStanModel methods} diff --git a/tests/testthat/resources/stan/parameter_types.stan b/tests/testthat/resources/stan/parameter_types.stan new file mode 100644 index 00000000..c54fede6 --- /dev/null +++ b/tests/testthat/resources/stan/parameter_types.stan @@ -0,0 +1,15 @@ +parameters { + real real_p; + vector[2] vector_p; + matrix[2, 2] matrix_p; + array[2] matrix[2, 2] array_matrix_p; + corr_matrix[2] corr_p; + array[2, 2] real array_array_real_p; + array[2, 2] vector[3] array_array_vector_p; + array[2, 2] matrix[3, 3] array_array_matrix_p; +// complex complex_p; +// complex_matrix[2, 2] complex_matrix_p; +// complex_vector[4] complex_vector_p; +// tuple(real, vector[3], array[2] matrix[2, 2], complex) tuple_int_vector_arraymatrix_complex_p; +// array[2] tuple(real, tuple(vector[2], array[2] tuple(real, complex, matrix[2, 2]))) arraytuple_big_p; +} diff --git a/tests/testthat/test-fit-init.R b/tests/testthat/test-fit-init.R new file mode 100644 index 00000000..2581bff1 --- /dev/null +++ b/tests/testthat/test-fit-init.R @@ -0,0 +1,69 @@ +library(cmdstanr) +set_cmdstan_path() + +mod_params <- testing_model("parameter_types") +mod_schools <- testing_model("schools") +mod_logistic <- testing_model("logistic") +data_list_schools <- testing_data("schools") +data_list_logistic <- testing_data("logistic") +test_inits <- function(mod, fit_init, data_list = NULL) { + utils::capture.output(fit_sample <- mod$sample(data = data_list, chains = 1, + init = fit_init, iter_sampling = 100, iter_warmup = 100, refresh = 0, + seed = 1234)) + utils::capture.output(fit_sample <- mod$sample(data = data_list, chains = 5, + init = fit_init, iter_sampling = 100, iter_warmup = 100, refresh = 0, + seed = 1234)) + utils::capture.output(fit_vb <- mod$variational(data = data_list, refresh = 0, + seed = 1234, init = fit_init, algorithm = "fullrank")) + utils::capture.output(fit_path <- mod$pathfinder(data = data_list, seed=1234, + refresh = 0, num_paths = 4, init = fit_init)) + utils::capture.output(fit_laplace <- mod$laplace(data = data_list, + seed = 1234, refresh=0, init=fit_init)) + utils::capture.output(fit_ml <- mod$optimize(data = data_list, seed = 1234, + refresh = 0, init = fit_init, history_size = 400, jacobian = TRUE, + algorithm = "lbfgs", tol_param = 1e-12, tol_rel_grad = 1e-12, + tol_grad = 1e-12, tol_rel_obj = 1e-12, tol_obj = 1e-12, init_alpha = 1e-4, + iter = 400)) + draws = posterior::as_draws_rvars(fit_sample$draws()) + utils::capture.output(fit_sample <- mod$sample(data = data_list, chains = 1, + init = draws, iter_sampling = 100, iter_warmup = 100, refresh = 0, + seed = 1234)) + return(0) +} + +test_that("Sample method works as init", { + utils::capture.output(fit_sample_init <- mod_params$sample(chains = 1, + iter_warmup = 100, iter_sampling = 100, refresh = 0, seed = 1234)) + utils::capture.output(fit_sample_multi_init <- mod_params$sample(chains = 4, init = fit_sample_init, + iter_warmup = 100, iter_sampling = 100, refresh = 0, seed = 1234)) + expect_no_error(test_inits(mod_params, fit_sample_init)) + expect_no_error(test_inits(mod_params, fit_sample_multi_init)) +}) + +test_that("Pathfinder method works as init", { + utils::capture.output(fit_path_init <- mod_params$pathfinder(seed=1234, + refresh = 0, num_paths = 4)) + expect_no_error(test_inits(mod_params, fit_path_init)) + utils::capture.output(fit_path_init <- mod_params$pathfinder(seed=1234, + refresh = 0, num_paths = 1)) + expect_no_error(test_inits(mod_params, fit_path_init)) +}) + +test_that("Laplace method works as init", { + utils::capture.output(fit_laplace_init <- mod_logistic$laplace( + data = data_list_logistic, seed = 1234, refresh=0)) + expect_no_error(test_inits(mod_logistic, fit_laplace_init, + data_list_logistic)) +}) + +test_that("Variational method works as init", { + utils::capture.output(fit_vb_init <- mod_logistic$variational( + data = data_list_logistic, seed=1234, refresh = 0)) + expect_no_error(test_inits(mod_logistic, fit_vb_init, data_list_logistic)) +}) + +test_that("Optimization method works as init", { + utils::capture.output(fit_ml_init <- mod_logistic$optimize( + data = data_list_logistic, seed=1234, refresh = 0)) + expect_no_error(test_inits(mod_logistic, fit_ml_init, data_list_logistic)) +})