From b65e8680732cfa079be118d0c397ed339db661da Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 30 Jul 2024 14:32:42 -0600 Subject: [PATCH] improve stability of r_eff calculation for loo see https://github.com/stan-dev/loo/issues/272 --- R/fit.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/fit.R b/R/fit.R index 68206c75..30e7310a 100644 --- a/R/fit.R +++ b/R/fit.R @@ -574,9 +574,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL, } else { draws <- self$draws(inc_warmup = inc_warmup) } - + draws <- maybe_convert_draws_format(draws, "draws_matrix") - + chains <- posterior::nchains(draws) model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"] @@ -595,7 +595,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL, uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE) names(unconstrained) <- repair_variable_names(uncon_names) unconstrained$.nchains <- chains - + do.call(function(...) { create_draws_format(format, ...) }, unconstrained) } CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws) @@ -1539,7 +1539,7 @@ loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) if (is.logical(r_eff)) { if (isTRUE(r_eff)) { r_eff_cores <- list(...)[["cores"]] %||% getOption("mc.cores", 1) - r_eff <- loo::relative_eff(exp(LLarray), cores = r_eff_cores) + r_eff <- loo::relative_eff(exp(LLarray + max(-LLarray)), cores = r_eff_cores) } else { r_eff <- NULL }