Skip to content

Commit

Permalink
Merge pull request #477 from fweber144/search_control
Browse files Browse the repository at this point in the history
Add `search_control`
  • Loading branch information
fweber144 authored Nov 22, 2023
2 parents 4fc0b19 + 38ba99a commit 83e8f9e
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 92 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ If you read this from a place other than <https://mc-stan.org/projpred/news/inde
* `print.vselsummary()` and `print.vsel()` now use a minimum number of significant digits of `2` by default. The previous behavior can be restored by setting `options(projpred.digits = getOption("digits"))`.
* Added a new performance statistic, the geometric mean predictive density (GMPD). This is particularly useful for discrete outcomes because there, the GMPD is a geometric mean of probabilities and hence bounded by zero and one. For details, see argument `stats` of the `?summary.vsel` help. (GitHub: #476)
* `project()`'s argument `verbose` now gets passed to argument `verbose_divmin` (not `projpred_verbose`) of the divergence minimizer function (see argument `div_minimizer` of `init_refmodel()`).
* Arguments `lambda_min_ratio`, `nlambda`, and `thresh` of `varsel()` and `cv_varsel()` have been deprecated. Instead, `varsel()` and `cv_varsel()` have gained a new argument called `search_control` which accepts control arguments for the search as a `list`. Thus, former arguments `lambda_min_ratio`, `nlambda`, and `thresh` should now be specified via `search_control` (but note that `search_control` is more general because it also accepts control arguments for a *forward* search). (GitHub: #477)

## Bug fixes

Expand Down
61 changes: 42 additions & 19 deletions R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@
#' @param ... For [cv_varsel.default()]: Arguments passed to [get_refmodel()] as
#' well as to [cv_varsel.refmodel()]. For [cv_varsel.vsel()]: Arguments passed
#' to [cv_varsel.refmodel()]. For [cv_varsel.refmodel()]: Arguments passed to
#' the divergence minimizer (during a forward search and also during the
#' evaluation part, but the latter only if `refit_prj` is `TRUE`).
#' the divergence minimizer (see argument `div_minimizer` of [init_refmodel()]
#' as well as section "Draw-wise divergence minimizers" of [projpred-package])
#' when refitting the submodels for the performance evaluation (if `refit_prj`
#' is `TRUE`).
#'
#' @inherit varsel details return
#'
Expand Down Expand Up @@ -196,9 +198,7 @@ cv_varsel.vsel <- function(
ndraws = object[["args_search"]][["ndraws"]],
nclusters = object[["args_search"]][["nclusters"]],
nterms_max = object[["args_search"]][["nterms_max"]],
lambda_min_ratio = object[["args_search"]][["lambda_min_ratio"]],
nlambda = object[["args_search"]][["nlambda"]],
thresh = object[["args_search"]][["thresh"]],
search_control = object[["args_search"]][["search_control"]],
penalty = object[["args_search"]][["penalty"]],
search_terms = object[["args_search"]][["search_terms"]],
cv_method = cv_method,
Expand Down Expand Up @@ -228,6 +228,7 @@ cv_varsel.refmodel <- function(
nloo = object$nobs,
K = if (!inherits(object, "datafit")) 5 else 10,
cvfits = object$cvfits,
search_control = NULL,
lambda_min_ratio = 1e-5,
nlambda = 150,
thresh = 1e-6,
Expand All @@ -238,6 +239,26 @@ cv_varsel.refmodel <- function(
parallel = getOption("projpred.prll_cv", FALSE),
...
) {
if (!missing(lambda_min_ratio)) {
warning("Argument `lambda_min_ratio` is deprecated. Please specify ",
"control arguments for the search via argument `search_control`. ",
"Now using `lambda_min_ratio` as element `lambda_min_ratio` of ",
"`search_control`.")
search_control$lambda_min_ratio <- lambda_min_ratio
}
if (!missing(nlambda)) {
warning("Argument `nlambda` is deprecated. Please specify control ",
"arguments for the search via argument `search_control`. ",
"Now using `nlambda` as element `nlambda` of `search_control`.")
search_control$nlambda <- nlambda
}
if (!missing(thresh)) {
warning("Argument `thresh` is deprecated. Please specify control ",
"arguments for the search via argument `search_control`. ",
"Now using `thresh` as element `thresh` of `search_control`.")
search_control$thresh <- thresh
}

if (exists(".Random.seed", envir = .GlobalEnv)) {
rng_state_old <- get(".Random.seed", envir = .GlobalEnv)
}
Expand Down Expand Up @@ -274,8 +295,6 @@ cv_varsel.refmodel <- function(
nloo <- args$nloo
K <- args$K
cvfits <- args$cvfits
# Arguments specific to the search:
opt <- nlist(lambda_min_ratio, nlambda, thresh)

# Full-data search:
if (!is.null(search_out)) {
Expand All @@ -293,7 +312,8 @@ cv_varsel.refmodel <- function(
search_path_fulldata <- select(
refmodel = refmodel, ndraws = ndraws, nclusters = nclusters,
method = method, nterms_max = nterms_max, penalty = penalty,
verbose = verbose, opt = opt, search_terms = search_terms,
verbose = verbose, search_control = search_control,
search_terms = search_terms,
search_terms_was_null = search_terms_was_null, ...
)
verb_out("-----", verbose = verbose)
Expand All @@ -319,7 +339,7 @@ cv_varsel.refmodel <- function(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, opt = opt, nloo = nloo,
verbose = verbose, search_control = search_control, nloo = nloo,
validate_search = validate_search,
search_path_fulldata = if (validate_search) {
# Not needed in this case, so for computational efficiency, avoiding
Expand All @@ -337,8 +357,8 @@ cv_varsel.refmodel <- function(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, opt = opt, K = K, cvfits = cvfits,
validate_search = validate_search,
verbose = verbose, search_control = search_control, K = K,
cvfits = cvfits, validate_search = validate_search,
search_path_fulldata = if (validate_search) {
# Not needed in this case, so for computational efficiency, avoiding
# passing the large object `search_path_fulldata` to loo_varsel():
Expand Down Expand Up @@ -395,8 +415,11 @@ cv_varsel.refmodel <- function(
validate_search,
cvfits,
args_search = nlist(
method, ndraws, nclusters, nterms_max, lambda_min_ratio,
nlambda, thresh, penalty,
method, ndraws, nclusters, nterms_max,
search_control = if (
method == "forward" && is.null(search_control)
) list(...) else search_control,
penalty,
search_terms = if (search_terms_was_null) NULL else search_terms
),
clust_used_search = refdist_info_search$clust_used,
Expand Down Expand Up @@ -500,7 +523,7 @@ parse_args_cv_varsel <- function(refmodel, cv_method, nloo, K, cvfits,
# all other arguments, see the documentation of cv_varsel().
loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nclusters, ndraws_pred, nclusters_pred, refit_prj,
penalty, verbose, opt, nloo, validate_search,
penalty, verbose, search_control, nloo, validate_search,
search_path_fulldata, search_terms,
search_terms_was_null, search_out_rks, parallel, ...) {
## Pre-processing ---------------------------------------------------------
Expand Down Expand Up @@ -918,8 +941,8 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
refmodel = refmodel, ndraws = ndraws, nclusters = nclusters,
reweighting_args = list(cl_ref = cl_sel, wdraws_ref = exp(lw[, i])),
method = method, nterms_max = nterms_max, penalty = penalty,
verbose = verbose_search, opt = opt, search_terms = search_terms,
est_runtime = FALSE, ...
verbose = verbose_search, search_control = search_control,
search_terms = search_terms, est_runtime = FALSE, ...
)
}

Expand Down Expand Up @@ -1189,7 +1212,7 @@ if (getRversion() >= package_version("2.15.1")) {

kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
ndraws_pred, nclusters_pred, refit_prj, penalty,
verbose, opt, K, cvfits, validate_search,
verbose, search_control, K, cvfits, validate_search,
search_path_fulldata, search_terms, search_out_rks,
parallel, ...) {
# Fetch the K reference model fits (or fit them now if not already done) and
Expand Down Expand Up @@ -1238,8 +1261,8 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
search_path <- select(
refmodel = fold$refmodel, ndraws = ndraws, nclusters = nclusters,
method = method, nterms_max = nterms_max, penalty = penalty,
verbose = verbose_search, opt = opt, search_terms = search_terms,
est_runtime = FALSE, ...
verbose = verbose_search, search_control = search_control,
search_terms = search_terms, est_runtime = FALSE, ...
)
}

Expand Down
30 changes: 18 additions & 12 deletions R/projfun.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Function to project the reference model onto a single submodel with predictor
# terms given in `predictor_terms`. Note that "single submodel" does not refer
# to a single fit (there are as many fits for this single submodel as there are
# projected draws). At the end, init_submodl() is called, so the output is of
# class `submodl`.
proj_to_submodl <- function(predictor_terms, p_ref, refmodel, ...) {
# projected draws). The case `is.null(search_control)` occurs in two situations:
# (i) when called from search_forward() with `...` as the intended control
# arguments and (ii) when called from perf_eval(). At the end, init_submodl() is
# called, so the output is of class `submodl`.
proj_to_submodl <- function(predictor_terms, p_ref, refmodel,
search_control = NULL, ...) {
y_unqs_aug <- refmodel$family$cats
if (refmodel$family$for_latent && !is.null(y_unqs_aug)) {
y_unqs_aug <- NULL
Expand All @@ -23,15 +26,18 @@ proj_to_submodl <- function(predictor_terms, p_ref, refmodel, ...) {
verb_out(" Projecting onto ", utils::tail(rhs_chr, 1))
}

outdmin <- refmodel$div_minimizer(
formula = fml_divmin,
data = subset$data,
family = refmodel$family,
weights = refmodel$wobs,
projpred_var = p_ref$var,
projpred_ws_aug = p_ref$mu,
...
)
args_divmin <- list(formula = fml_divmin,
data = subset$data,
family = refmodel$family,
weights = refmodel$wobs,
projpred_var = p_ref$var,
projpred_ws_aug = p_ref$mu)
if (!is.null(search_control)) {
args_divmin <- c(args_divmin, search_control)
} else {
args_divmin <- c(args_divmin, list(...))
}
outdmin <- do.call(refmodel$div_minimizer, args_divmin)

if (isTRUE(getOption("projpred.check_conv", FALSE))) {
check_conv(outdmin)
Expand Down
22 changes: 12 additions & 10 deletions R/search.R
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ force_search_terms <- function(forced_terms, optional_terms) {
}

search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max,
penalty, opt) {
penalty, search_control) {

## predictive mean and variance of the reference model (with parameters
## integrated out)
Expand All @@ -212,13 +212,15 @@ search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max,
## (Notice: here we use pmax = nterms_max+1 so that the computation gets
## carried until all the way down to the least regularization also for model
## size nterms_max)
search <- glm_elnet(d_train$x, mu, family,
lambda_min_ratio = opt$lambda_min_ratio,
nlambda = opt$nlambda,
pmax = nterms_max + 1, pmax_strict = FALSE,
weights = d_train$weights,
intercept = intercept, obsvar = v, penalty = penalty,
thresh = opt$thresh)
search <- glm_elnet(
d_train$x, mu, family,
lambda_min_ratio = search_control$lambda_min_ratio %||% 1e-5,
nlambda = search_control$nlambda %||% 150,
pmax = nterms_max + 1, pmax_strict = FALSE,
weights = d_train$weights,
intercept = intercept, obsvar = v, penalty = penalty,
thresh = search_control$thresh %||% 1e-6
)

## sort the variables according to the order in which they enter the model in
## the L1-path
Expand Down Expand Up @@ -282,7 +284,7 @@ search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max,
return(out)
}

search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) {
search_L1 <- function(p_ref, refmodel, nterms_max, penalty, search_control) {
if (nterms_max == 0) {
stop("L1 search cannot be used for an empty (i.e. intercept-only) ",
"full-model formula or `nterms_max = 0`.")
Expand Down Expand Up @@ -314,7 +316,7 @@ search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) {
terms_ <- attr(tt, "term.labels")
search_path <- search_L1_surrogate(
p_ref, nlist(x, weights = refmodel$wobs), refmodel$family,
intercept = TRUE, ncol(x), penalty, opt
intercept = TRUE, ncol(x), penalty, search_control
)

predictor_ranking_orig <- collapse_ranked_predictors(
Expand Down
Loading

0 comments on commit 83e8f9e

Please sign in to comment.