Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve efficiency of model methods, tidy code #960

Merged
merged 3 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,8 @@ unconstrain_variables <- function(variables) {
" not provided!", call. = FALSE)
}

# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
model_variables$parameters <- model_variables$parameters[nonzero_length_params]

stan_pars <- process_init(list(variables), num_procs = 1, model_variables)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
variables_vector <- unlist(variables[model_par_names], recursive = TRUE)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)

Expand Down Expand Up @@ -571,11 +567,11 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
call. = FALSE)
}
if (!is.null(files)) {
read_csv <- read_cmdstan_csv(files = files, format = "draws_df")
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
draws <- read_csv$post_warmup_draws
}
if (!is.null(draws)) {
draws <- maybe_convert_draws_format(draws, "draws_df")
draws <- maybe_convert_draws_format(draws, "draws_matrix")
}
} else {
if (is.null(private$draws_)) {
Expand All @@ -584,7 +580,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
}
private$read_csv_(format = "draws_df")
}
draws <- maybe_convert_draws_format(private$draws_, "draws_df")
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
}

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
Expand All @@ -599,19 +595,10 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
pars <- names(model_variables$parameters[nonzero_length_params])

draws <- posterior::subset_draws(draws, variable = pars)
skeleton <- self$variable_skeleton(transformed_parameters = FALSE,
generated_quantities = FALSE)
par_columns <- !(names(draws) %in% c(".chain", ".iteration", ".draw"))
meta_columns <- !par_columns
unconstrained <- lapply(asplit(draws, 1), function(draw) {
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
self$unconstrain_variables(variables = par_list)
})

unconstrained <- do.call(rbind.data.frame, unconstrained)
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
maybe_convert_draws_format(cbind.data.frame(unconstrained, draws[,meta_columns]), format)
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

Expand Down
7 changes: 6 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,10 @@ compile <- function(quiet = TRUE,
stanc_options = list(),
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
compile_model_methods = FALSE,
compile_hessian_method = FALSE,
compile_standalone = FALSE,
dry_run = FALSE,
#deprecated
compile_hessian_method = FALSE,
threads = FALSE) {

if (length(self$stan_file()) == 0) {
Expand Down Expand Up @@ -505,6 +505,11 @@ compile <- function(quiet = TRUE,
cpp_options[["stan_threads"]] <- TRUE
}

# temporary deprecation warnings
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is compiled with all models.")
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
exe_base <- self$stan_file()
Expand Down
21 changes: 7 additions & 14 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,19 @@ valid_draws_formats <- function() {
"draws_rvars", "rvars")
}

maybe_convert_draws_format <- function(draws, format) {
maybe_convert_draws_format <- function(draws, format, ...) {
if (is.null(draws)) {
return(draws)
}
format <- sub("^draws_", "", format)
switch(
format,
"array" = posterior::as_draws_array(draws),
"df" = posterior::as_draws_df(draws),
"data.frame" = posterior::as_draws_df(draws),
"list" = posterior::as_draws_list(draws),
"matrix" = posterior::as_draws_matrix(draws),
"rvars" = posterior::as_draws_rvars(draws),
"array" = posterior::as_draws_array(draws, ...),
"df" = posterior::as_draws_df(draws, ...),
"data.frame" = posterior::as_draws_df(draws, ...),
"list" = posterior::as_draws_list(draws, ...),
"matrix" = posterior::as_draws_matrix(draws, ...),
"rvars" = posterior::as_draws_rvars(draws, ...),
stop("Invalid draws format.", call. = FALSE)
)
}
Expand Down Expand Up @@ -757,13 +757,6 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))

if (hessian) {
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
}

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
invisible(NULL)
Expand Down
41 changes: 0 additions & 41 deletions inst/include/hessian.cpp

This file was deleted.

75 changes: 55 additions & 20 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <Rcpp.h>
#include <rcpp_eigen_interop.hpp>
#include <stan/model/model_base.hpp>
#include <stan/model/log_prob_grad.hpp>
#include <stan/model/log_prob_propto.hpp>
Expand Down Expand Up @@ -26,10 +27,14 @@ using json_data_t = stan::json::json_data;
return std::make_shared<json_data_t>(data_context);
}

stan::model::model_base&
new_model(stan::io::var_context& data_context, unsigned int seed,
std::ostream* msg_stream);

// [[Rcpp::export]]
Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
Rcpp::XPtr<stan_model> ptr(
new stan_model(*var_context(data_path), seed, &Rcpp::Rcout)
Rcpp::XPtr<stan::model::model_base> ptr(
&new_model(*var_context(data_path), seed, &Rcpp::Rcout)
);
Rcpp::XPtr<boost::ecuyer1988> base_rng(new boost::ecuyer1988(seed));
return Rcpp::List::create(
Expand All @@ -39,37 +44,56 @@ Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
}

// [[Rcpp::export]]
double log_prob(SEXP ext_model_ptr, std::vector<double> upars,
bool jac_adjust) {
double log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jac_adjust) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<int> params_i;
if (jac_adjust) {
return stan::model::log_prob_propto<true>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
return stan::model::log_prob_propto<true>(*ptr.get(), upars, &Rcpp::Rcout);
} else {
return stan::model::log_prob_propto<false>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
return stan::model::log_prob_propto<false>(*ptr.get(), upars, &Rcpp::Rcout);
}
}

// [[Rcpp::export]]
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, std::vector<double> upars,
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars,
bool jac_adjust) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<double> gradients;
std::vector<int> params_i;
Eigen::VectorXd gradients;

double lp;
if (jac_adjust) {
lp = stan::model::log_prob_grad<true, true>(
*ptr.get(), upars, params_i, gradients);
lp = stan::model::log_prob_grad<true, true>(*ptr.get(), upars, gradients);
} else {
lp = stan::model::log_prob_grad<true, false>(
*ptr.get(), upars, params_i, gradients);
lp = stan::model::log_prob_grad<true, false>(*ptr.get(), upars, gradients);
}
Rcpp::NumericVector grad_rtn = Rcpp::wrap(gradients);
Rcpp::NumericVector grad_rtn(Rcpp::wrap(std::move(gradients)));
grad_rtn.attr("log_prob") = lp;
return grad_rtn;
}

// [[Rcpp::export]]
Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);

auto hessian_functor = [&](auto&& x) {
if (jacobian) {
return ptr->log_prob<true, true>(x, 0);
} else {
return ptr->log_prob<true, false>(x, 0);
}
};

double log_prob;
Eigen::VectorXd grad;
Eigen::MatrixXd hessian;

stan::math::internal::finite_diff_hessian_auto(hessian_functor, upars, log_prob, grad, hessian);

return Rcpp::List::create(
Rcpp::Named("log_prob") = log_prob,
Rcpp::Named("grad_log_prob") = grad,
Rcpp::Named("hessian") = hessian);
}

// [[Rcpp::export]]
size_t get_num_upars(SEXP ext_model_ptr) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Expand All @@ -95,12 +119,23 @@ Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
}

// [[Rcpp::export]]
std::vector<double> unconstrain_variables(SEXP ext_model_ptr, std::string init_path) {
Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variables) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<int> params_i;
std::vector<double> vars;
ptr->transform_inits(*var_context(init_path), params_i, vars, &Rcpp::Rcout);
return vars;
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables, unconstrained_variables, &Rcpp::Rcout);
return unconstrained_variables;
}

// [[Rcpp::export]]
Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows());
for (int i = 0; i < variables.rows(); i++) {
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
unconstrained_draws.col(i) = unconstrained_variables;
}
return unconstrained_draws.transpose();
}

// [[Rcpp::export]]
Expand Down
8 changes: 4 additions & 4 deletions man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,10 @@ test_that("Methods error if not compiled", {
)
})

test_that("User warned about higher-order autodiff with hessian", {
skip_if(os_is_wsl())
expect_message(
fit$init_model_methods(hessian = TRUE, verbose = TRUE),
"The hessian method relies on higher-order autodiff which is still experimental. Please report any compilation errors that you encounter",
fixed = TRUE
)
})

test_that("Methods return correct values", {
skip_if(os_is_wsl())
fit$init_model_methods(verbose = TRUE)
lp <- fit$log_prob(unconstrained_variables=c(0.1))
expect_equal(lp, -8.6327599208828509347)

Expand Down
Loading