Skip to content

Commit

Permalink
Fix a bug in get_stat()'s handling of NAs in case of
Browse files Browse the repository at this point in the history
subsampled PSIS-LOO CV (`nloo`) with `deltas = TRUE` and `baseline = "best"`.

For `baseline = "ref"`, this is only a refactor improving the safety and
readability of `get_stat()`'s handling of `NA`s because for `baseline = "ref"`,
we should always have no `NA`s in `lppd.bs` and `mu.bs`, so in that case,
`n_notna` did not require an adjustment and also because the math operations
connecting `mu` and `mu.bs` (analogously for `lppd` and `lppd.bs`) ensured that
only the "inner join" of non-`NA` elements (i.e., the set of observations for
which both `mu` and `mu.bs` (analogously for `lppd` and `lppd.bs`) are not `NA`)
is used.

This addresses question 1 from
<stan-dev#94 (comment)>.
  • Loading branch information
fweber144 committed Nov 9, 2023
1 parent a976fd2 commit 7232148
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions R/summary_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,30 @@ get_stat <- function(mu, lppd, y_wobs_test, stat, mu.bs = NULL, lppd.bs = NULL,
wcv = NULL, alpha = 0.1, ...) {
n_notna.bs <- NULL
if (stat %in% c("mlpd", "elpd")) {
n <- length(lppd)
n_notna <- sum(!is.na(lppd))
if (!is.null(lppd.bs)) {
# Compute the performance statistics using only those observations for
# which both `lppd` and `lppd.bs` are not `NA`:
lppd[is.na(lppd.bs)] <- NA
lppd.bs[is.na(lppd)] <- NA
n_notna.bs <- sum(!is.na(lppd.bs))
}
n_notna <- sum(!is.na(lppd))
n <- length(lppd)
} else {
n <- length(mu)
n_notna <- sum(!is.na(mu) & !is.na(y_wobs_test$y_prop %||% y_wobs_test$y))
if (!is.null(mu.bs)) {
# Compute the performance statistics using only those observations for
# which both `mu` and `mu.bs` are not `NA`:
mu[is.na(mu.bs)] <- NA
mu.bs[is.na(mu)] <- NA
n_notna.bs <- sum(!is.na(mu.bs))
}
n_notna <- sum(!is.na(mu) & !is.na(y_wobs_test$y_prop %||% y_wobs_test$y))
n <- length(mu)
}
if (n_notna == 0 || (!is.null(n_notna.bs) && n_notna.bs == 0)) {
if (!is.null(n_notna.bs) && getOption("projpred.additional_checks", FALSE)) {
stopifnot(n_notna == n_notna.bs)
}
if (n_notna == 0) {
return(list(value = NA, se = NA, lq = NA, uq = NA))
}

Expand All @@ -312,6 +323,7 @@ get_stat <- function(mu, lppd, y_wobs_test, stat, mu.bs = NULL, lppd.bs = NULL,

alpha_half <- alpha / 2
one_minus_alpha_half <- 1 - alpha_half

if (stat %in% c("mlpd", "elpd")) {
if (!is.null(lppd.bs)) {
value <- sum((lppd - lppd.bs) * wcv, na.rm = TRUE)
Expand Down Expand Up @@ -349,10 +361,6 @@ get_stat <- function(mu, lppd, y_wobs_test, stat, mu.bs = NULL, lppd.bs = NULL,
}
} else if (stat == "rmse") {
if (!is.null(mu.bs)) {
# Compute the RMSEs using only those observations for which both `mu`
# and `mu.bs` are not `NA`:
mu[is.na(mu.bs)] <- NA
mu.bs[is.na(mu)] <- NA
value <- sqrt(mean(wcv * (mu - y)^2, na.rm = TRUE)) -
sqrt(mean(wcv * (mu.bs - y)^2, na.rm = TRUE))
diffvalue.bootstrap <- bootstrap(
Expand Down Expand Up @@ -428,10 +436,6 @@ get_stat <- function(mu, lppd, y_wobs_test, stat, mu.bs = NULL, lppd.bs = NULL,
}
} else if (stat == "auc") {
if (!is.null(mu.bs)) {
# Compute the AUCs using only those observations for which both `mu` and
# `mu.bs` are not `NA`:
mu[is.na(mu.bs)] <- NA
mu.bs[is.na(mu)] <- NA
auc.data <- cbind(y, mu, wcv)
auc.data.bs <- cbind(y, mu.bs, wcv)
value <- auc(auc.data) - auc(auc.data.bs)
Expand Down

0 comments on commit 7232148

Please sign in to comment.