Skip to content

Commit

Permalink
Merge pull request #960 from stan-dev/model-methods-speedup
Browse files Browse the repository at this point in the history
Improve efficiency of model methods, tidy code
  • Loading branch information
andrjohns authored May 5, 2024
2 parents 15aa9d9 + fa1b10b commit c4d6e80
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 108 deletions.
27 changes: 7 additions & 20 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,8 @@ unconstrain_variables <- function(variables) {
" not provided!", call. = FALSE)
}

# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
model_variables$parameters <- model_variables$parameters[nonzero_length_params]

stan_pars <- process_init(list(variables), num_procs = 1, model_variables)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
variables_vector <- unlist(variables[model_par_names], recursive = TRUE)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)

Expand Down Expand Up @@ -571,11 +567,11 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
call. = FALSE)
}
if (!is.null(files)) {
read_csv <- read_cmdstan_csv(files = files, format = "draws_df")
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
draws <- read_csv$post_warmup_draws
}
if (!is.null(draws)) {
draws <- maybe_convert_draws_format(draws, "draws_df")
draws <- maybe_convert_draws_format(draws, "draws_matrix")
}
} else {
if (is.null(private$draws_)) {
Expand All @@ -584,7 +580,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
}
private$read_csv_(format = "draws_df")
}
draws <- maybe_convert_draws_format(private$draws_, "draws_df")
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
}

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
Expand All @@ -599,19 +595,10 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
pars <- names(model_variables$parameters[nonzero_length_params])

draws <- posterior::subset_draws(draws, variable = pars)
skeleton <- self$variable_skeleton(transformed_parameters = FALSE,
generated_quantities = FALSE)
par_columns <- !(names(draws) %in% c(".chain", ".iteration", ".draw"))
meta_columns <- !par_columns
unconstrained <- lapply(asplit(draws, 1), function(draw) {
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
self$unconstrain_variables(variables = par_list)
})

unconstrained <- do.call(rbind.data.frame, unconstrained)
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
maybe_convert_draws_format(cbind.data.frame(unconstrained, draws[,meta_columns]), format)
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

Expand Down
7 changes: 6 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,10 @@ compile <- function(quiet = TRUE,
stanc_options = list(),
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
compile_hessian_method = FALSE,
compile_standalone = FALSE,
dry_run = FALSE,
#deprecated
compile_hessian_method = FALSE,
threads = FALSE) {

if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -505,6 +505,11 @@ compile <- function(quiet = TRUE,
cpp_options[["stan_threads"]] <- TRUE
}

# temporary deprecation warnings
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is compiled with all models.")
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
exe_base <- self$stan_file()
Expand Down
21 changes: 7 additions & 14 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,19 @@ valid_draws_formats <- function() {
"draws_rvars", "rvars")
}

maybe_convert_draws_format <- function(draws, format) {
maybe_convert_draws_format <- function(draws, format, ...) {
if (is.null(draws)) {
return(draws)
}
format <- sub("^draws_", "", format)
switch(
format,
"array" = posterior::as_draws_array(draws),
"df" = posterior::as_draws_df(draws),
"data.frame" = posterior::as_draws_df(draws),
"list" = posterior::as_draws_list(draws),
"matrix" = posterior::as_draws_matrix(draws),
"rvars" = posterior::as_draws_rvars(draws),
"array" = posterior::as_draws_array(draws, ...),
"df" = posterior::as_draws_df(draws, ...),
"data.frame" = posterior::as_draws_df(draws, ...),
"list" = posterior::as_draws_list(draws, ...),
"matrix" = posterior::as_draws_matrix(draws, ...),
"rvars" = posterior::as_draws_rvars(draws, ...),
stop("Invalid draws format.", call. = FALSE)
)
}
Expand Down Expand Up @@ -757,13 +757,6 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))

if (hessian) {
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
}

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
invisible(NULL)
Expand Down
41 changes: 0 additions & 41 deletions inst/include/hessian.cpp

This file was deleted.

75 changes: 55 additions & 20 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <Rcpp.h>
#include <rcpp_eigen_interop.hpp>
#include <stan/model/model_base.hpp>
#include <stan/model/log_prob_grad.hpp>
#include <stan/model/log_prob_propto.hpp>
Expand Down Expand Up @@ -26,10 +27,14 @@ using json_data_t = stan::json::json_data;
return std::make_shared<json_data_t>(data_context);
}

stan::model::model_base&
new_model(stan::io::var_context& data_context, unsigned int seed,
std::ostream* msg_stream);

// [[Rcpp::export]]
Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
Rcpp::XPtr<stan_model> ptr(
new stan_model(*var_context(data_path), seed, &Rcpp::Rcout)
Rcpp::XPtr<stan::model::model_base> ptr(
&new_model(*var_context(data_path), seed, &Rcpp::Rcout)
);
Rcpp::XPtr<boost::ecuyer1988> base_rng(new boost::ecuyer1988(seed));
return Rcpp::List::create(
Expand All @@ -39,37 +44,56 @@ Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
}

// [[Rcpp::export]]
double log_prob(SEXP ext_model_ptr, std::vector<double> upars,
bool jac_adjust) {
double log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jac_adjust) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<int> params_i;
if (jac_adjust) {
return stan::model::log_prob_propto<true>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
return stan::model::log_prob_propto<true>(*ptr.get(), upars, &Rcpp::Rcout);
} else {
return stan::model::log_prob_propto<false>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
return stan::model::log_prob_propto<false>(*ptr.get(), upars, &Rcpp::Rcout);
}
}

// [[Rcpp::export]]
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, std::vector<double> upars,
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars,
bool jac_adjust) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<double> gradients;
std::vector<int> params_i;
Eigen::VectorXd gradients;

double lp;
if (jac_adjust) {
lp = stan::model::log_prob_grad<true, true>(
*ptr.get(), upars, params_i, gradients);
lp = stan::model::log_prob_grad<true, true>(*ptr.get(), upars, gradients);
} else {
lp = stan::model::log_prob_grad<true, false>(
*ptr.get(), upars, params_i, gradients);
lp = stan::model::log_prob_grad<true, false>(*ptr.get(), upars, gradients);
}
Rcpp::NumericVector grad_rtn = Rcpp::wrap(gradients);
Rcpp::NumericVector grad_rtn(Rcpp::wrap(std::move(gradients)));
grad_rtn.attr("log_prob") = lp;
return grad_rtn;
}

// [[Rcpp::export]]
Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);

auto hessian_functor = [&](auto&& x) {
if (jacobian) {
return ptr->log_prob<true, true>(x, 0);
} else {
return ptr->log_prob<true, false>(x, 0);
}
};

double log_prob;
Eigen::VectorXd grad;
Eigen::MatrixXd hessian;

stan::math::internal::finite_diff_hessian_auto(hessian_functor, upars, log_prob, grad, hessian);

return Rcpp::List::create(
Rcpp::Named("log_prob") = log_prob,
Rcpp::Named("grad_log_prob") = grad,
Rcpp::Named("hessian") = hessian);
}

// [[Rcpp::export]]
size_t get_num_upars(SEXP ext_model_ptr) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Expand All @@ -95,12 +119,23 @@ Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
}

// [[Rcpp::export]]
std::vector<double> unconstrain_variables(SEXP ext_model_ptr, std::string init_path) {
Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variables) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<int> params_i;
std::vector<double> vars;
ptr->transform_inits(*var_context(init_path), params_i, vars, &Rcpp::Rcout);
return vars;
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables, unconstrained_variables, &Rcpp::Rcout);
return unconstrained_variables;
}

// [[Rcpp::export]]
Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows());
for (int i = 0; i < variables.rows(); i++) {
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
unconstrained_draws.col(i) = unconstrained_variables;
}
return unconstrained_draws.transpose();
}

// [[Rcpp::export]]
Expand Down
8 changes: 4 additions & 4 deletions man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,10 @@ test_that("Methods error if not compiled", {
)
})

test_that("User warned about higher-order autodiff with hessian", {
skip_if(os_is_wsl())
expect_message(
fit$init_model_methods(hessian = TRUE, verbose = TRUE),
"The hessian method relies on higher-order autodiff which is still experimental. Please report any compilation errors that you encounter",
fixed = TRUE
)
})

test_that("Methods return correct values", {
skip_if(os_is_wsl())
fit$init_model_methods(verbose = TRUE)
lp <- fit$log_prob(unconstrained_variables=c(0.1))
expect_equal(lp, -8.6327599208828509347)

Expand Down

0 comments on commit c4d6e80

Please sign in to comment.