Refactors (warning for cv_search, solution_terms, term fit) (#263)
fweber144 authored Jan 8, 2022
1 parent b894b36 commit d20b542
Showing 16 changed files with 201 additions and 187 deletions.
22 changes: 12 additions & 10 deletions R/cv_varsel.R
@@ -372,12 +372,12 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nterms_max = nterms_max, penalty = penalty, verbose = FALSE, opt = opt,
search_terms = search_terms
)
solution_terms <- search_path$solution_terms

## project onto the selected models and compute the prediction accuracy for
## the full data
submodels <- .get_submodels(
search_path = search_path, nterms = c(0, seq_along(solution_terms)),
search_path = search_path,
nterms = c(0, seq_along(search_path$solution_terms)),
p_ref = p_pred, refmodel = refmodel, regul = opt$regul,
cv_search = cv_search
)
@@ -392,7 +392,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,

## compute approximate LOO with PSIS weights
for (k in seq_along(submodels)) {
mu_k <- refmodel$family$mu_fun(submodels[[k]]$sub_fit,
mu_k <- refmodel$family$mu_fun(submodels[[k]]$submodl,
obs = inds,
offset = refmodel$offset[inds])
log_lik_sub <- t(refmodel$family$ll_fun(
@@ -418,13 +418,14 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
data = refmodel$fetch_data(),
add_main_effects = FALSE)
## with `match` we get the indices of the variables as they enter the
## solution path in solution_terms
solution <- match(solution_terms, setdiff(candidate_terms, "1"))
## solution path in `search_path$solution_terms`
solution <- match(search_path$solution_terms,
setdiff(candidate_terms, "1"))
for (i in seq_len(n)) {
solution_terms_mat[i, seq_along(solution)] <- solution
}
sel <- nlist(search_path, kl = sapply(submodels, function(x) x$kl),
solution_terms)
solution_terms = search_path$solution_terms)
} else {
if (verbose) {
print(msg)
@@ -451,12 +452,12 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nterms_max = nterms_max, penalty = penalty, verbose = FALSE, opt = opt,
search_terms = search_terms
)
solution_terms <- search_path$solution_terms

## project onto the selected models and compute the prediction accuracy
## for the left-out point
submodels <- .get_submodels(
search_path = search_path, nterms = c(0, seq_along(solution_terms)),
search_path = search_path,
nterms = c(0, seq_along(search_path$solution_terms)),
p_ref = p_pred, refmodel = refmodel, regul = opt$regul,
cv_search = cv_search
)
@@ -472,8 +473,9 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
data = refmodel$fetch_data(),
add_main_effects = FALSE)
## with `match` we get the indices of the variables as they enter the
## solution path in solution_terms
solution <- match(solution_terms, setdiff(candidate_terms, "1"))
## solution path in `search_path$solution_terms`
solution <- match(search_path$solution_terms,
setdiff(candidate_terms, "1"))
solution_terms_mat[i, seq_along(solution)] <- solution

if (verbose) {
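A side note on the `match()` idiom touched in both hunks above: it maps each element of `search_path$solution_terms` to its position among the candidate terms after the intercept `"1"` has been dropped, and those positions are what get stored in `solution_terms_mat`. A minimal base-R sketch with made-up term names (not taken from the package):

```r
# Hypothetical candidate terms from the full formula; "1" is the intercept.
candidate_terms <- c("1", "x1", "x2", "x3")
# Hypothetical solution path: the order in which terms entered the submodel.
solution_terms <- c("x2", "x3")
# Positions of the solution terms within the intercept-free candidate set:
match(solution_terms, setdiff(candidate_terms, "1"))
#> [1] 2 3
```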
8 changes: 6 additions & 2 deletions R/divergence_minimizers.R
@@ -412,8 +412,8 @@ check_conv <- function(fit) {

# Prediction functions for submodels --------------------------------------

subprd <- function(fit, newdata) {
return(do.call(cbind, lapply(fit, function(fit) {
subprd <- function(fits, newdata) {
return(do.call(cbind, lapply(fits, function(fit) {
# Only pass argument `allow.new.levels` to the predict() generic if the fit
# is multilevel:
has_grp <- inherits(fit, c("lmerMod", "glmerMod"))
@@ -446,6 +446,10 @@ predict.subfit <- function(subfit, newdata = NULL) {
if (is.null(beta)) {
return(as.matrix(rep(alpha, NROW(x))))
} else {
if (ncol(x) != length(beta) + 1L) {
stop("The number of columns in the model matrix (\"X\") doesn't match ",
"the number of coefficients.")
}
return(x %*% rbind(alpha, beta))
}
}
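The new check in `predict.subfit()` guards the matrix product `x %*% rbind(alpha, beta)`: the model matrix needs exactly one column for the intercept plus one per coefficient. A small sketch with invented shapes (a single projected draw, three observations, two predictors), not package code:

```r
x     <- cbind(1, matrix(rnorm(3 * 2), nrow = 3))  # model matrix incl. intercept column
alpha <- 0.5                                       # intercept estimate
beta  <- matrix(c(1.2, -0.7), ncol = 1)            # one coefficient per predictor
stopifnot(ncol(x) == length(beta) + 1L)            # the consistency checked above
eta <- x %*% rbind(alpha, beta)                    # 3 x 1 matrix of linear predictors
```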
11 changes: 2 additions & 9 deletions R/methods.R
@@ -159,13 +159,6 @@ proj_helper <- function(object, newdata,
count_terms_chosen(proj$solution_terms, add_icpt = TRUE)
})

solution_terms <- list(...)$solution_terms
if (!is.null(solution_terms) &&
length(solution_terms) > NCOL(newdata)) {
stop("The number of solution terms is greater than the number of columns ",
"in `newdata`.")
}

preds <- lapply(projs, function(proj) {
w_o <- proj$refmodel$extract_model_data(
proj$refmodel$fit, newdata = newdata, wrhs = weightsnew, orhs = offsetnew,
@@ -179,7 +172,7 @@ proj_helper <- function(object, newdata,
if (length(offsetnew) == 0) {
offsetnew <- rep(0, NROW(newdata))
}
mu <- proj$refmodel$family$mu_fun(proj$sub_fit,
mu <- proj$refmodel$family$mu_fun(proj$submodl,
newdata = newdata, offset = offsetnew)
onesub_fun(proj, mu, weightsnew,
offset = offsetnew, newdata = newdata,
@@ -1053,7 +1046,7 @@ as.matrix.projection <- function(x, ...) {
warning("Note that projection was performed using clustering and the ",
"clusters might have different weights.")
}
res <- do.call(rbind, lapply(x$sub_fit, get_subparams))
res <- do.call(rbind, lapply(x$submodl, get_subparams))
if (x$refmodel$family$family == "gaussian") res <- cbind(res, sigma = x$dis)
return(res)
}
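The last hunk above makes `as.matrix.projection()` collect the projected parameters from `x$submodl` instead of `x$sub_fit`; the result is still one row per projected draw, with a `sigma` column appended for Gaussian families. A hedged usage sketch, assuming `prj` is the return value of [project()] for a Gaussian reference model:

```r
draws <- as.matrix(prj)  # `prj` assumed to be a projection object (see above)
dim(draws)               # one row per projected draw, one column per parameter (+ sigma)
colMeans(draws)          # means of the projected submodel parameters
```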
10 changes: 8 additions & 2 deletions R/project.R
@@ -59,7 +59,8 @@
#' \item{`solution_terms`}{A character vector of the submodel's
#' predictor terms, ordered in the way in which the terms were added to the
#' submodel.}
#' \item{`sub_fit`}{The submodel's fitted model object.}
#' \item{`submodl`}{A `list` containing the submodel fits (one fit per
#' projected draw).}
#' \item{`p_type`}{A single logical value indicating whether the
#' reference model's posterior draws have been clustered for the projection
#' (`TRUE`) or not (`FALSE`).}
@@ -135,6 +136,11 @@ project <- function(object, nterms = NULL, solution_terms = NULL,
cv_search <- TRUE
}

if (!cv_search) {
warning("Currently, `cv_search = FALSE` requires some caution, see GitHub ",
"issues #168 and #211.")
}

if (!is.null(solution_terms)) {
## if solution_terms is given, nterms is ignored
## (project only onto the given submodel)
@@ -215,7 +221,7 @@ project <- function(object, nterms = NULL, solution_terms = NULL,
search_path = nlist(
solution_terms,
p_sel = object$search_path$p_sel,
sub_fits = object$search_path$sub_fits
submodls = object$search_path$submodls
),
nterms = nterms, p_ref = p_ref, refmodel = refmodel, regul = regul,
cv_search = cv_search
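From a user's perspective, the main visible change in this file is the new warning: whenever [project()] ends up using `cv_search = FALSE` (i.e., re-using the fits from the search instead of projecting again), it now points to the open issues. A hedged usage sketch, assuming `vs` is a `vsel` object from a forward search and that `cv_search = FALSE` is not overridden internally:

```r
prj <- project(vs, nterms = 2, cv_search = FALSE)
#> Warning: Currently, `cv_search = FALSE` requires some caution, see GitHub
#> issues #168 and #211.
```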
29 changes: 15 additions & 14 deletions R/projfun.R
@@ -12,7 +12,7 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
data = refmodel$fetch_data(), y = p_ref$mu
)

sub_fit <- refmodel$div_minimizer(
submodl <- refmodel$div_minimizer(
formula = flatten_formula(subset$formula),
data = subset$data,
family = refmodel$family,
@@ -22,11 +22,11 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
)

if (isTRUE(getOption("projpred.check_conv", FALSE))) {
check_conv(sub_fit)
check_conv(submodl)
}

return(.init_submodel(
sub_fit = sub_fit, p_ref = p_ref, refmodel = refmodel,
submodl = submodl, p_ref = p_ref, refmodel = refmodel,
solution_terms = solution_terms, wobs = wobs, wsample = wsample
))
}
@@ -37,25 +37,26 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
.get_submodels <- function(search_path, nterms, p_ref, refmodel, regul,
cv_search = FALSE) {
if (!cv_search) {
## simply fetch the already computed quantities for each submodel size
# In this case, simply fetch the already computed projections, so don't
# project again.
fetch_submodel <- function(nterms) {
validparams <- .validate_wobs_wsample(
refmodel$wobs, search_path$p_sel$weights, search_path$p_sel$mu
)
wobs <- validparams$wobs
wsample <- validparams$wsample

## reuse sub_fit as projected during search
sub_refit <- search_path$sub_fits[[nterms + 1]]

return(.init_submodel(
sub_fit = sub_refit, p_ref = search_path$p_sel, refmodel = refmodel,
# Re-use the submodel fits from the search:
submodl = search_path$submodls[[nterms + 1]],
p_ref = search_path$p_sel,
refmodel = refmodel,
solution_terms = utils::head(search_path$solution_terms, nterms),
wobs = wobs, wsample = wsample
wobs = wobs,
wsample = wsample
))
}
} else {
## need to project again for each submodel size
# In this case, project again.
fetch_submodel <- function(nterms) {
return(project_submodel(
solution_terms = utils::head(search_path$solution_terms, nterms),
@@ -83,7 +84,7 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
return(nlist(wobs, wsample))
}

.init_submodel <- function(sub_fit, p_ref, refmodel, solution_terms, wobs,
.init_submodel <- function(submodl, p_ref, refmodel, solution_terms, wobs,
wsample) {
p_ref$mu <- refmodel$family$linkinv(
refmodel$family$linkfun(p_ref$mu) + refmodel$offset
@@ -114,13 +115,13 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
###
}

mu <- refmodel$family$mu_fun(sub_fit, offset = refmodel$offset)
mu <- refmodel$family$mu_fun(submodl, offset = refmodel$offset)
dis <- refmodel$family$dis_fun(p_ref, nlist(mu), wobs)
kl <- weighted.mean(
refmodel$family$kl(p_ref,
nlist(weights = wobs),
nlist(mu, dis)),
wsample
)
return(nlist(dis, kl, weights = wsample, solution_terms, sub_fit))
return(nlist(dis, kl, weights = wsample, solution_terms, submodl))
}
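The `nterms + 1` indexing in the `cv_search = FALSE` branch above exists because the list of submodel fits returned by the search starts at the intercept-only model, so the fit of size `nterms` sits at position `nterms + 1`. An illustrative toy list (placeholders instead of real fit objects):

```r
submodls <- list("<0 terms>", "<1 term>", "<2 terms>")  # hypothetical placeholders
nterms <- 2
submodls[[nterms + 1]]
#> [1] "<2 terms>"
```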
18 changes: 9 additions & 9 deletions R/refmodel.R
@@ -84,13 +84,13 @@
#' + `newdata` accepts either `NULL` (for using the original dataset,
#' typically stored in `fit`) or data for new observations (at least in the
#' form of a `data.frame`).
#' * `proj_predfun`: `proj_predfun(fit, newdata)` where:
#' + `fit` accepts a `list` of length \eqn{S_{\mbox{prj}}}{S_prj} containing
#' this number of submodel fits. This `list` is the same as that returned by
#' [project()] in its output element `sub_fit` (which in turn is the same as
#' the return value of `div_minimizer`, except if [project()] was used with
#' an `object` of class `vsel` based on an L1 search as well as with
#' `cv_search = FALSE`).
#' * `proj_predfun`: `proj_predfun(fits, newdata)` where:
#' + `fits` accepts a `list` of length \eqn{S_{\mbox{prj}}}{S_prj}
#' containing this number of submodel fits. This `list` is the same as that
#' returned by [project()] in its output element `submodl` (which in turn is
#' the same as the return value of `div_minimizer`, except if [project()]
#' was used with an `object` of class `vsel` based on an L1 search as well
#' as with `cv_search = FALSE`).
#' + `newdata` accepts data for new observations (at least in the form of a
#' `data.frame`).
#' * `div_minimizer` does not need to have a specific prototype, but it needs to
@@ -563,14 +563,14 @@ init_refmodel <- function(object, data, formula, family, ref_predfun = NULL,
family <- extend_family(family)
}

family$mu_fun <- function(fit, obs = NULL, newdata = NULL, offset = NULL) {
family$mu_fun <- function(fits, obs = NULL, newdata = NULL, offset = NULL) {
newdata <- fetch_data(data, obs = obs, newdata = newdata)
if (is.null(offset)) {
offset <- rep(0, nrow(newdata))
} else {
stopifnot(length(offset) %in% c(1L, nrow(newdata)))
}
family$linkinv(proj_predfun(fit, newdata = newdata) + offset)
family$linkinv(proj_predfun(fits, newdata = newdata) + offset)
}

# Special case: `datafit` -------------------------------------------------
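The documentation change above renames the first argument of `proj_predfun` from `fit` to `fits` to stress that it receives a `list` of submodel fits (one per projected draw), each contributing one column of predictions. A minimal sketch of a custom `proj_predfun` for [init_refmodel()], under the assumption that plain `predict()` works for the submodel fits (illustration only, not projpred internals):

```r
my_proj_predfun <- function(fits, newdata) {
  # One column of linear predictors per submodel fit (i.e., per projected draw):
  do.call(cbind, lapply(fits, function(fit) {
    predict(fit, newdata = newdata)
  }))
}
```

This mirrors the structure of the internal `subprd()` helper shown in `R/divergence_minimizers.R` above.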
8 changes: 4 additions & 4 deletions R/search.R
@@ -25,7 +25,7 @@ search_forward <- function(p_ref, refmodel, nterms_max, verbose = TRUE, opt,
chosen <- c(chosen, cands[imin])

## append submodels
submodels <- c(submodels, list(subL[[imin]]$sub_fit))
submodels <- c(submodels, list(subL[[imin]]$submodl))

if (verbose && count_terms_chosen(chosen) %in% iq) {
print(paste0(names(iq)[max(which(count_terms_chosen(chosen) == iq))],
@@ -35,7 +35,7 @@

## reduce chosen to a list of non-redundant accumulated models
return(list(solution_terms = setdiff(reduce_models(chosen), "1"),
sub_fits = submodels))
submodls = submodels))
}

# copied over from search until we resolve the TODO below
@@ -150,7 +150,7 @@ search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) {
refmodel$formula, colnames(x)[search_path$solution_terms],
refmodel$fetch_data()
)
sub_fits <- lapply(0:length(solution_terms), function(nterms) {
submodls <- lapply(0:length(solution_terms), function(nterms) {
if (nterms == 0) {
formula <- make_formula(c("1"))
beta <- NULL
@@ -186,5 +186,5 @@ search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) {
return(list(sub))
})
return(list(solution_terms = solution_terms[seq_len(nterms_max)],
sub_fits = sub_fits[seq_len(nterms_max + 1)]))
submodls = submodls[seq_len(nterms_max + 1)]))
}
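Both search functions now return the renamed element `submodls` alongside `solution_terms`: a character vector giving the order in which terms entered the path, and a list of submodel fits starting at the intercept-only model (for `search_L1()`, explicitly truncated to `nterms_max + 1` elements). A schematic of the return value, with placeholders instead of real fit objects:

```r
list(
  solution_terms = c("x2", "x1"),  # hypothetical order of entry into the path
  submodls = list("<intercept only>", "<+ x2>", "<+ x2 + x1>")  # one entry per model size
)
```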
2 changes: 1 addition & 1 deletion R/summary_funs.R
@@ -5,7 +5,7 @@
weights = refmodel$wobs[test_points]),
family = refmodel$family,
wsample = model$weights,
mu = refmodel$family$mu_fun(model$sub_fit,
mu = refmodel$family$mu_fun(model$submodl,
obs = test_points,
offset = refmodel$offset[test_points]),
dis = model$dis
3 changes: 2 additions & 1 deletion man/project.Rd


14 changes: 7 additions & 7 deletions man/refmodel-init-get.Rd
