Skip to content

Commit

Permalink
improve stability of r_eff calculation for loo
Browse files Browse the repository at this point in the history
  • Loading branch information
jgabry committed Jul 30, 2024
1 parent c549ae6 commit b65e868
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
} else {
draws <- self$draws(inc_warmup = inc_warmup)
}

draws <- maybe_convert_draws_format(draws, "draws_matrix")

chains <- posterior::nchains(draws)

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
Expand All @@ -595,7 +595,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
unconstrained$.nchains <- chains

do.call(function(...) { create_draws_format(format, ...) }, unconstrained)
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
Expand Down Expand Up @@ -1539,7 +1539,7 @@ loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...)
if (is.logical(r_eff)) {
if (isTRUE(r_eff)) {
r_eff_cores <- list(...)[["cores"]] %||% getOption("mc.cores", 1)
r_eff <- loo::relative_eff(exp(LLarray), cores = r_eff_cores)
r_eff <- loo::relative_eff(exp(LLarray + max(-LLarray)), cores = r_eff_cores)
} else {
r_eff <- NULL
}
Expand Down

0 comments on commit b65e868

Please sign in to comment.