Skip to content

Commit

Permalink
Merge pull request #273 from stan-dev/avoid-under-and-overflows-in-st…
Browse files Browse the repository at this point in the history
…acking

avoid under and overflows in stacking
  • Loading branch information
jgabry authored Aug 5, 2024
2 parents b1f7a5a + 568f29b commit 6e7001e
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions R/loo_model_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,12 @@ stacking_weights <-
stop("At least two models are required for stacking weights.")
}

exp_lpd_point <- exp(lpd_point)
negative_log_score_loo <- function(w) {
# objective function: log score
stopifnot(length(w) == K - 1)
w_full <- c(w, 1 - sum(w))
sum <- 0
for (i in 1:N) {
sum <- sum + log(exp(lpd_point[i, ]) %*% w_full)
}
# avoid over- and underflows using log weights and rowLogSumExps
sum <- sum(matrixStats::rowLogSumExps(sweep(lpd_point[1:N,], 2, log(w_full), '+')))
return(-as.numeric(sum))
}

Expand All @@ -274,11 +271,11 @@ stacking_weights <-
stopifnot(length(w) == K - 1)
w_full <- c(w, 1 - sum(w))
grad <- rep(0, K - 1)
# avoid over- and underflows using log weights, rowLogSumExps,
# and by subtracting the row maximum of lpd_point
mlpd <- matrixStats::rowMaxs(lpd_point)
for (k in 1:(K - 1)) {
for (i in 1:N) {
grad[k] <- grad[k] +
(exp_lpd_point[i, k] - exp_lpd_point[i, K]) / (exp_lpd_point[i,] %*% w_full)
}
grad[k] <- sum((exp(lpd_point[, k] - mlpd) - exp(lpd_point[, K] - mlpd)) / exp(matrixStats::rowLogSumExps(sweep(lpd_point, 2, log(w_full), '+')) - mlpd))
}
return(-grad)
}
Expand Down

0 comments on commit 6e7001e

Please sign in to comment.