Skip to content

Commit

Permalink
Add handling for variable order, fix chains return in draws
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed May 5, 2024
1 parent d34b77e commit bf397e4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
4 changes: 2 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ unconstrain_variables <- function(variables) {
" not provided!", call. = FALSE)
}

variables_vector <- unlist(variables, recursive = TRUE, use.names = FALSE)
variables_vector <- unlist(variables[model_par_names], recursive = TRUE)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)
Expand Down Expand Up @@ -598,7 +598,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
maybe_convert_draws_format(unconstrained, format)
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

Expand Down
7 changes: 6 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,10 @@ compile <- function(quiet = TRUE,
stanc_options = list(),
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
compile_hessian_method = FALSE,
compile_standalone = FALSE,
dry_run = FALSE,
#deprecated
compile_hessian_method = FALSE,
threads = FALSE) {

if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -505,6 +505,11 @@ compile <- function(quiet = TRUE,
cpp_options[["stan_threads"]] <- TRUE
}

# temporary deprecation warnings
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is compiled with all models.")
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
exe_base <- self$stan_file()
Expand Down
14 changes: 7 additions & 7 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,19 @@ valid_draws_formats <- function() {
"draws_rvars", "rvars")
}

maybe_convert_draws_format <- function(draws, format) {
maybe_convert_draws_format <- function(draws, format, ...) {
if (is.null(draws)) {
return(draws)
}
format <- sub("^draws_", "", format)
switch(
format,
"array" = posterior::as_draws_array(draws),
"df" = posterior::as_draws_df(draws),
"data.frame" = posterior::as_draws_df(draws),
"list" = posterior::as_draws_list(draws),
"matrix" = posterior::as_draws_matrix(draws),
"rvars" = posterior::as_draws_rvars(draws),
"array" = posterior::as_draws_array(draws, ...),
"df" = posterior::as_draws_df(draws, ...),
"data.frame" = posterior::as_draws_df(draws, ...),
"list" = posterior::as_draws_list(draws, ...),
"matrix" = posterior::as_draws_matrix(draws, ...),
"rvars" = posterior::as_draws_rvars(draws, ...),
stop("Invalid draws format.", call. = FALSE)
)
}
Expand Down

0 comments on commit bf397e4

Please sign in to comment.