From adaa1e69177b15f9007452ff40c0246ec70631ff Mon Sep 17 00:00:00 2001 From: fweber144 Date: Sat, 18 Nov 2023 22:30:45 +0100 Subject: [PATCH 1/4] Add argument `search_control` accepting tuning parameters (in R also known as "control arguments") for the search. The main reason for this is to allow different tuning parameters for `refmodel$div_minimizer()` in `search_forward()` and `perf_eval()`. However, this change is also the most straightforward solution to ensure that fold-wise searches from `cv_varsel.vsel()` use the same tuning parameters as the previously run full-data search (`args_search` did not take `...` into account). --- R/cv_varsel.R | 40 +++++++++++------------- R/projfun.R | 24 +++++++++------ R/search.R | 22 ++++++++------ R/varsel.R | 52 +++++++++++++++++--------------- man/cv_varsel.Rd | 39 +++++++++++++----------- man/varsel.Rd | 39 +++++++++++++----------- tests/testthat/helpers/testers.R | 2 +- tests/testthat/test_datafit.R | 5 +-- 8 files changed, 120 insertions(+), 103 deletions(-) diff --git a/R/cv_varsel.R b/R/cv_varsel.R index 4c2d46636..202248e78 100644 --- a/R/cv_varsel.R +++ b/R/cv_varsel.R @@ -59,8 +59,10 @@ #' @param ... For [cv_varsel.default()]: Arguments passed to [get_refmodel()] as #' well as to [cv_varsel.refmodel()]. For [cv_varsel.vsel()]: Arguments passed #' to [cv_varsel.refmodel()]. For [cv_varsel.refmodel()]: Arguments passed to -#' the divergence minimizer (during a forward search and also during the -#' evaluation part, but the latter only if `refit_prj` is `TRUE`). +#' the divergence minimizer (see argument `div_minimizer` of [init_refmodel()] +#' as well as section "Draw-wise divergence minimizers" of [projpred-package]) +#' when refitting the submodels for the performance evaluation (if `refit_prj` +#' is `TRUE`). #' #' @inherit varsel details return #' @@ -196,9 +198,7 @@ cv_varsel.vsel <- function( ndraws = object[["args_search"]][["ndraws"]], nclusters = object[["args_search"]][["nclusters"]], nterms_max = object[["args_search"]][["nterms_max"]], - lambda_min_ratio = object[["args_search"]][["lambda_min_ratio"]], - nlambda = object[["args_search"]][["nlambda"]], - thresh = object[["args_search"]][["thresh"]], + search_control = object[["args_search"]][["search_control"]], penalty = object[["args_search"]][["penalty"]], search_terms = object[["args_search"]][["search_terms"]], cv_method = cv_method, @@ -228,9 +228,7 @@ cv_varsel.refmodel <- function( nloo = object$nobs, K = if (!inherits(object, "datafit")) 5 else 10, cvfits = object$cvfits, - lambda_min_ratio = 1e-5, - nlambda = 150, - thresh = 1e-6, + search_control = list(), validate_search = TRUE, seed = NA, search_terms = NULL, @@ -274,8 +272,6 @@ cv_varsel.refmodel <- function( nloo <- args$nloo K <- args$K cvfits <- args$cvfits - # Arguments specific to the search: - opt <- nlist(lambda_min_ratio, nlambda, thresh) # Full-data search: if (!is.null(search_out)) { @@ -293,7 +289,8 @@ cv_varsel.refmodel <- function( search_path_fulldata <- select( refmodel = refmodel, ndraws = ndraws, nclusters = nclusters, method = method, nterms_max = nterms_max, penalty = penalty, - verbose = verbose, opt = opt, search_terms = search_terms, + verbose = verbose, search_control = search_control, + search_terms = search_terms, search_terms_was_null = search_terms_was_null, ... ) verb_out("-----", verbose = verbose) @@ -319,7 +316,7 @@ cv_varsel.refmodel <- function( refmodel = refmodel, method = method, nterms_max = nterms_max, ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred, nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty, - verbose = verbose, opt = opt, nloo = nloo, + verbose = verbose, search_control = search_control, nloo = nloo, validate_search = validate_search, search_path_fulldata = if (validate_search) { # Not needed in this case, so for computational efficiency, avoiding @@ -337,8 +334,8 @@ cv_varsel.refmodel <- function( refmodel = refmodel, method = method, nterms_max = nterms_max, ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred, nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty, - verbose = verbose, opt = opt, K = K, cvfits = cvfits, - validate_search = validate_search, + verbose = verbose, search_control = search_control, K = K, + cvfits = cvfits, validate_search = validate_search, search_path_fulldata = if (validate_search) { # Not needed in this case, so for computational efficiency, avoiding # passing the large object `search_path_fulldata` to loo_varsel(): @@ -395,8 +392,7 @@ cv_varsel.refmodel <- function( validate_search, cvfits, args_search = nlist( - method, ndraws, nclusters, nterms_max, lambda_min_ratio, - nlambda, thresh, penalty, + method, ndraws, nclusters, nterms_max, search_control, penalty, search_terms = if (search_terms_was_null) NULL else search_terms ), clust_used_search = refdist_info_search$clust_used, @@ -500,7 +496,7 @@ parse_args_cv_varsel <- function(refmodel, cv_method, nloo, K, cvfits, # all other arguments, see the documentation of cv_varsel(). loo_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters, ndraws_pred, nclusters_pred, refit_prj, - penalty, verbose, opt, nloo, validate_search, + penalty, verbose, search_control, nloo, validate_search, search_path_fulldata, search_terms, search_terms_was_null, search_out_rks, parallel, ...) { ## Pre-processing --------------------------------------------------------- @@ -918,8 +914,8 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws, refmodel = refmodel, ndraws = ndraws, nclusters = nclusters, reweighting_args = list(cl_ref = cl_sel, wdraws_ref = exp(lw[, i])), method = method, nterms_max = nterms_max, penalty = penalty, - verbose = verbose_search, opt = opt, search_terms = search_terms, - est_runtime = FALSE, ... + verbose = verbose_search, search_control = search_control, + search_terms = search_terms, est_runtime = FALSE, ... ) } @@ -1189,7 +1185,7 @@ if (getRversion() >= package_version("2.15.1")) { kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters, ndraws_pred, nclusters_pred, refit_prj, penalty, - verbose, opt, K, cvfits, validate_search, + verbose, search_control, K, cvfits, validate_search, search_path_fulldata, search_terms, search_out_rks, parallel, ...) { # Fetch the K reference model fits (or fit them now if not already done) and @@ -1238,8 +1234,8 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters, search_path <- select( refmodel = fold$refmodel, ndraws = ndraws, nclusters = nclusters, method = method, nterms_max = nterms_max, penalty = penalty, - verbose = verbose_search, opt = opt, search_terms = search_terms, - est_runtime = FALSE, ... + verbose = verbose_search, search_control = search_control, + search_terms = search_terms, est_runtime = FALSE, ... ) } diff --git a/R/projfun.R b/R/projfun.R index 1e46c1759..44a46bd17 100644 --- a/R/projfun.R +++ b/R/projfun.R @@ -3,7 +3,8 @@ # to a single fit (there are as many fits for this single submodel as there are # projected draws). At the end, init_submodl() is called, so the output is of # class `submodl`. -proj_to_submodl <- function(predictor_terms, p_ref, refmodel, ...) { +proj_to_submodl <- function(predictor_terms, p_ref, refmodel, + search_control = list(), ...) { y_unqs_aug <- refmodel$family$cats if (refmodel$family$for_latent && !is.null(y_unqs_aug)) { y_unqs_aug <- NULL @@ -23,15 +24,18 @@ proj_to_submodl <- function(predictor_terms, p_ref, refmodel, ...) { verb_out(" Projecting onto ", utils::tail(rhs_chr, 1)) } - outdmin <- refmodel$div_minimizer( - formula = fml_divmin, - data = subset$data, - family = refmodel$family, - weights = refmodel$wobs, - projpred_var = p_ref$var, - projpred_ws_aug = p_ref$mu, - ... - ) + args_divmin <- list(formula = fml_divmin, + data = subset$data, + family = refmodel$family, + weights = refmodel$wobs, + projpred_var = p_ref$var, + projpred_ws_aug = p_ref$mu) + if (length(search_control) > 0) { + args_divmin <- c(args_divmin, search_control) + } else { + args_divmin <- c(args_divmin, list(...)) + } + outdmin <- do.call(refmodel$div_minimizer, args_divmin) if (isTRUE(getOption("projpred.check_conv", FALSE))) { check_conv(outdmin) diff --git a/R/search.R b/R/search.R index 601666b5a..d879206d1 100644 --- a/R/search.R +++ b/R/search.R @@ -196,7 +196,7 @@ force_search_terms <- function(forced_terms, optional_terms) { } search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max, - penalty, opt) { + penalty, search_control) { ## predictive mean and variance of the reference model (with parameters ## integrated out) @@ -212,13 +212,15 @@ search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max, ## (Notice: here we use pmax = nterms_max+1 so that the computation gets ## carried until all the way down to the least regularization also for model ## size nterms_max) - search <- glm_elnet(d_train$x, mu, family, - lambda_min_ratio = opt$lambda_min_ratio, - nlambda = opt$nlambda, - pmax = nterms_max + 1, pmax_strict = FALSE, - weights = d_train$weights, - intercept = intercept, obsvar = v, penalty = penalty, - thresh = opt$thresh) + search <- glm_elnet( + d_train$x, mu, family, + lambda_min_ratio = search_control$lambda_min_ratio %||% 1e-5, + nlambda = search_control$nlambda %||% 150, + pmax = nterms_max + 1, pmax_strict = FALSE, + weights = d_train$weights, + intercept = intercept, obsvar = v, penalty = penalty, + thresh = search_control$thresh %||% 1e-6 + ) ## sort the variables according to the order in which they enter the model in ## the L1-path @@ -282,7 +284,7 @@ search_L1_surrogate <- function(p_ref, d_train, family, intercept, nterms_max, return(out) } -search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) { +search_L1 <- function(p_ref, refmodel, nterms_max, penalty, search_control) { if (nterms_max == 0) { stop("L1 search cannot be used for an empty (i.e. intercept-only) ", "full-model formula or `nterms_max = 0`.") @@ -314,7 +316,7 @@ search_L1 <- function(p_ref, refmodel, nterms_max, penalty, opt) { terms_ <- attr(tt, "term.labels") search_path <- search_L1_surrogate( p_ref, nlist(x, weights = refmodel$wobs), refmodel$family, - intercept = TRUE, ncol(x), penalty, opt + intercept = TRUE, ncol(x), penalty, search_control ) predictor_ranking_orig <- collapse_ranked_predictors( diff --git a/R/varsel.R b/R/varsel.R index 37699f7b2..ddaeac01b 100644 --- a/R/varsel.R +++ b/R/varsel.R @@ -49,16 +49,21 @@ #' those predictors have no cost and will therefore be selected first, whereas #' `Inf` means those predictors will never be selected. If `NULL`, then `1` is #' used for each predictor. -#' @param lambda_min_ratio Only relevant for L1 search. Ratio between the -#' smallest and largest lambda in the L1-penalized search. This parameter -#' essentially determines how long the search is carried out, i.e., how large -#' submodels are explored. No need to change this unless the program gives a -#' warning about this. -#' @param nlambda Only relevant for L1 search. Number of values in the lambda -#' grid for L1-penalized search. No need to change this unless the program -#' gives a warning about this. -#' @param thresh Only relevant for L1 search. Convergence threshold when -#' computing the L1 path. Usually, there is no need to change this. +#' @param search_control A `list` of "control" arguments (i.e., tuning +#' parameters) for the search. In case of forward search, these arguments are +#' passed to the divergence minimizer (see argument `div_minimizer` of +#' [init_refmodel()] as well as section "Draw-wise divergence minimizers" of +#' [projpred-package]). In case of L1 search, possible arguments are: +#' * `lambda_min_ratio`: Ratio between the smallest and largest lambda in the +#' L1-penalized search (default: `1e-5`). This parameter essentially +#' determines how long the search is carried out, i.e., how large submodels +#' are explored. No need to change this unless the program gives a warning +#' about this. +#' * `nlambda`: Number of values in the lambda grid for L1-penalized search +#' (default: `150`). No need to change this unless the program gives a warning +#' about this. +#' * `thresh`: Convergence threshold when computing the L1 path (default: +#' `1e-6`). Usually, there is no need to change this. #' @param search_terms Only relevant for forward search. A custom character #' vector of predictor term blocks to consider for the search. Section #' "Details" below describes more precisely what "predictor term block" means. @@ -80,8 +85,10 @@ #' @param ... For [varsel.default()]: Arguments passed to [get_refmodel()] as #' well as to [varsel.refmodel()]. For [varsel.vsel()]: Arguments passed to #' [varsel.refmodel()]. For [varsel.refmodel()]: Arguments passed to the -#' divergence minimizer (during a forward search and also during the -#' evaluation part, but the latter only if `refit_prj` is `TRUE`). +#' divergence minimizer (see argument `div_minimizer` of [init_refmodel()] as +#' well as section "Draw-wise divergence minimizers" of [projpred-package]) +#' when refitting the submodels for the performance evaluation (if `refit_prj` +#' is `TRUE`). #' #' @details #' @@ -204,9 +211,7 @@ varsel.vsel <- function(object, ...) { ndraws = object[["args_search"]][["ndraws"]], nclusters = object[["args_search"]][["nclusters"]], nterms_max = object[["args_search"]][["nterms_max"]], - lambda_min_ratio = object[["args_search"]][["lambda_min_ratio"]], - nlambda = object[["args_search"]][["nlambda"]], - thresh = object[["args_search"]][["thresh"]], + search_control = object[["args_search"]][["search_control"]], penalty = object[["args_search"]][["penalty"]], search_terms = object[["args_search"]][["search_terms"]], search_out = list(search_path = object[["search_path"]]), @@ -221,8 +226,8 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", nclusters_pred = NULL, refit_prj = !inherits(object, "datafit"), nterms_max = NULL, verbose = TRUE, - lambda_min_ratio = 1e-5, nlambda = 150, - thresh = 1e-6, penalty = NULL, search_terms = NULL, + search_control = list(), + penalty = NULL, search_terms = NULL, search_out = NULL, seed = NA, ...) { if (exists(".Random.seed", envir = .GlobalEnv)) { rng_state_old <- get(".Random.seed", envir = .GlobalEnv) @@ -314,12 +319,12 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", if (!is.null(search_out)) { search_path <- search_out[["search_path"]] } else { - opt <- nlist(lambda_min_ratio, nlambda, thresh) verb_out("-----\nRunning the search ...", verbose = verbose) search_path <- select( refmodel = refmodel, ndraws = ndraws, nclusters = nclusters, method = method, nterms_max = nterms_max, penalty = penalty, - verbose = verbose, opt = opt, search_terms = search_terms, + verbose = verbose, search_control = search_control, + search_terms = search_terms, search_terms_was_null = search_terms_was_null, ... ) verb_out("-----", verbose = verbose) @@ -416,8 +421,7 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", cvfits = refmodel$cvfits, ### args_search = nlist( - method, ndraws, nclusters, nterms_max, lambda_min_ratio, - nlambda, thresh, penalty, + method, ndraws, nclusters, nterms_max, search_control, penalty, search_terms = if (search_terms_was_null) NULL else search_terms ), clust_used_search = search_path$p_sel$clust_used, @@ -446,7 +450,7 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", # of fits per model size being equal to the number of projected draws), and # `p_sel` (the output from get_refdist() for the search). select <- function(refmodel, ndraws, nclusters, reweighting_args = NULL, method, - nterms_max, penalty, verbose, opt, ...) { + nterms_max, penalty, verbose, search_control, ...) { if (is.null(reweighting_args)) { p_sel <- get_refdist(refmodel, ndraws = ndraws, nclusters = nclusters) } else { @@ -460,12 +464,12 @@ select <- function(refmodel, ndraws, nclusters, reweighting_args = NULL, method, if (method == "L1") { search_path <- search_L1( p_ref = p_sel, refmodel = refmodel, nterms_max = nterms_max, - penalty = penalty, opt = opt + penalty = penalty, search_control = search_control ) } else if (method == "forward") { search_path <- search_forward( p_ref = p_sel, refmodel = refmodel, nterms_max = nterms_max, - verbose = verbose, ... + verbose = verbose, search_control = search_control, ... ) } search_path$p_sel <- p_sel diff --git a/man/cv_varsel.Rd b/man/cv_varsel.Rd index f48c0397d..5bfa4d4fd 100644 --- a/man/cv_varsel.Rd +++ b/man/cv_varsel.Rd @@ -36,9 +36,7 @@ cv_varsel(object, ...) nloo = object$nobs, K = if (!inherits(object, "datafit")) 5 else 10, cvfits = object$cvfits, - lambda_min_ratio = 1e-05, - nlambda = 150, - thresh = 1e-06, + search_control = list(), validate_search = TRUE, seed = NA, search_terms = NULL, @@ -55,8 +53,10 @@ cv_varsel(object, ...) \item{...}{For \code{\link[=cv_varsel.default]{cv_varsel.default()}}: Arguments passed to \code{\link[=get_refmodel]{get_refmodel()}} as well as to \code{\link[=cv_varsel.refmodel]{cv_varsel.refmodel()}}. For \code{\link[=cv_varsel.vsel]{cv_varsel.vsel()}}: Arguments passed to \code{\link[=cv_varsel.refmodel]{cv_varsel.refmodel()}}. For \code{\link[=cv_varsel.refmodel]{cv_varsel.refmodel()}}: Arguments passed to -the divergence minimizer (during a forward search and also during the -evaluation part, but the latter only if \code{refit_prj} is \code{TRUE}).} +the divergence minimizer (see argument \code{div_minimizer} of \code{\link[=init_refmodel]{init_refmodel()}} +as well as section "Draw-wise divergence minimizers" of \link{projpred-package}) +when refitting the submodels for the performance evaluation (if \code{refit_prj} +is \code{TRUE}).} \item{cv_method}{The CV method, either \code{"LOO"} or \code{"kfold"}. In the \code{"LOO"} case, a Pareto-smoothed importance sampling leave-one-out CV (PSIS-LOO CV) @@ -135,18 +135,23 @@ used for each predictor.} \item{verbose}{A single logical value indicating whether to print out additional information during the computations.} -\item{lambda_min_ratio}{Only relevant for L1 search. Ratio between the -smallest and largest lambda in the L1-penalized search. This parameter -essentially determines how long the search is carried out, i.e., how large -submodels are explored. No need to change this unless the program gives a -warning about this.} - -\item{nlambda}{Only relevant for L1 search. Number of values in the lambda -grid for L1-penalized search. No need to change this unless the program -gives a warning about this.} - -\item{thresh}{Only relevant for L1 search. Convergence threshold when -computing the L1 path. Usually, there is no need to change this.} +\item{search_control}{A \code{list} of "control" arguments (i.e., tuning +parameters) for the search. In case of forward search, these arguments are +passed to the divergence minimizer (see argument \code{div_minimizer} of +\code{\link[=init_refmodel]{init_refmodel()}} as well as section "Draw-wise divergence minimizers" of +\link{projpred-package}). In case of L1 search, possible arguments are: +\itemize{ +\item \code{lambda_min_ratio}: Ratio between the smallest and largest lambda in the +L1-penalized search (default: \code{1e-5}). This parameter essentially +determines how long the search is carried out, i.e., how large submodels +are explored. No need to change this unless the program gives a warning +about this. +\item \code{nlambda}: Number of values in the lambda grid for L1-penalized search +(default: \code{150}). No need to change this unless the program gives a warning +about this. +\item \code{thresh}: Convergence threshold when computing the L1 path (default: +\code{1e-6}). Usually, there is no need to change this. +}} \item{seed}{Pseudorandom number generation (PRNG) seed by which the same results can be obtained again if needed. Passed to argument \code{seed} of diff --git a/man/varsel.Rd b/man/varsel.Rd index ddc73277a..71e0f0b74 100644 --- a/man/varsel.Rd +++ b/man/varsel.Rd @@ -24,9 +24,7 @@ varsel(object, ...) refit_prj = !inherits(object, "datafit"), nterms_max = NULL, verbose = TRUE, - lambda_min_ratio = 1e-05, - nlambda = 150, - thresh = 1e-06, + search_control = list(), penalty = NULL, search_terms = NULL, search_out = NULL, @@ -42,8 +40,10 @@ varsel(object, ...) \item{...}{For \code{\link[=varsel.default]{varsel.default()}}: Arguments passed to \code{\link[=get_refmodel]{get_refmodel()}} as well as to \code{\link[=varsel.refmodel]{varsel.refmodel()}}. For \code{\link[=varsel.vsel]{varsel.vsel()}}: Arguments passed to \code{\link[=varsel.refmodel]{varsel.refmodel()}}. For \code{\link[=varsel.refmodel]{varsel.refmodel()}}: Arguments passed to the -divergence minimizer (during a forward search and also during the -evaluation part, but the latter only if \code{refit_prj} is \code{TRUE}).} +divergence minimizer (see argument \code{div_minimizer} of \code{\link[=init_refmodel]{init_refmodel()}} as +well as section "Draw-wise divergence minimizers" of \link{projpred-package}) +when refitting the submodels for the performance evaluation (if \code{refit_prj} +is \code{TRUE}).} \item{d_test}{A \code{list} of the structure outlined in section "Argument \code{d_test}" below, providing test data for evaluating the predictive @@ -89,18 +89,23 @@ does not count the intercept.)} \item{verbose}{A single logical value indicating whether to print out additional information during the computations.} -\item{lambda_min_ratio}{Only relevant for L1 search. Ratio between the -smallest and largest lambda in the L1-penalized search. This parameter -essentially determines how long the search is carried out, i.e., how large -submodels are explored. No need to change this unless the program gives a -warning about this.} - -\item{nlambda}{Only relevant for L1 search. Number of values in the lambda -grid for L1-penalized search. No need to change this unless the program -gives a warning about this.} - -\item{thresh}{Only relevant for L1 search. Convergence threshold when -computing the L1 path. Usually, there is no need to change this.} +\item{search_control}{A \code{list} of "control" arguments (i.e., tuning +parameters) for the search. In case of forward search, these arguments are +passed to the divergence minimizer (see argument \code{div_minimizer} of +\code{\link[=init_refmodel]{init_refmodel()}} as well as section "Draw-wise divergence minimizers" of +\link{projpred-package}). In case of L1 search, possible arguments are: +\itemize{ +\item \code{lambda_min_ratio}: Ratio between the smallest and largest lambda in the +L1-penalized search (default: \code{1e-5}). This parameter essentially +determines how long the search is carried out, i.e., how large submodels +are explored. No need to change this unless the program gives a warning +about this. +\item \code{nlambda}: Number of values in the lambda grid for L1-penalized search +(default: \code{150}). No need to change this unless the program gives a warning +about this. +\item \code{thresh}: Convergence threshold when computing the L1 path (default: +\code{1e-6}). Usually, there is no need to change this. +}} \item{penalty}{Only relevant for L1 search. A numeric vector determining the relative penalties or costs for the predictors. A value of \code{0} means that diff --git a/tests/testthat/helpers/testers.R b/tests/testthat/helpers/testers.R index dde40b393..f95182f58 100644 --- a/tests/testthat/helpers/testers.R +++ b/tests/testthat/helpers/testers.R @@ -2411,7 +2411,7 @@ vsel_tester <- function( NULL }, nterms_max = vs$nterms_max, - lambda_min_ratio = 1e-5, nlambda = 150, thresh = 1e-6, + search_control = list(), penalty = penalty_expected, search_terms = if (is.null(search_terms_expected)) { NULL diff --git a/tests/testthat/test_datafit.R b/tests/testthat/test_datafit.R index adf6c0449..3678b6d0d 100644 --- a/tests/testthat/test_datafit.R +++ b/tests/testthat/test_datafit.R @@ -682,8 +682,9 @@ test_that(paste( ) vs <- suppressWarnings(varsel( ref, - method = "L1", lambda_min_ratio = lambda_min_ratio, - nlambda = nlambda, thresh = 1e-12, verbose = FALSE + method = "L1", + search_control = nlist(lambda_min_ratio, nlambda, thresh = 1e-12), + verbose = FALSE )) pred1 <- proj_linpred(vs, newdata = data.frame(x = x, weights = weights), From cb963caea04f3e8b4af9b6feb467ae24c6b38db0 Mon Sep 17 00:00:00 2001 From: fweber144 Date: Sun, 19 Nov 2023 14:42:44 +0100 Subject: [PATCH 2/4] Use `search_control = NULL` by default. This has the advantage that in `proj_to_submodl()`, `is.null(search_control)` can serve as an indicator for using `...` instead of `search_control` and that `search_control = list()` can be specified when the defaults of the corresponding underlying draw-wise divergence minimizer should be used in the search, but not in the performance evaluation. --- R/cv_varsel.R | 8 ++++++-- R/projfun.R | 10 ++++++---- R/varsel.R | 12 +++++++++--- man/cv_varsel.Rd | 6 ++++-- man/varsel.Rd | 6 ++++-- tests/testthat/helpers/testers.R | 7 ++++++- tests/testthat/test_varsel.R | 29 +++++++++++++++++++++++++++++ 7 files changed, 64 insertions(+), 14 deletions(-) diff --git a/R/cv_varsel.R b/R/cv_varsel.R index 202248e78..750d69ee0 100644 --- a/R/cv_varsel.R +++ b/R/cv_varsel.R @@ -228,7 +228,7 @@ cv_varsel.refmodel <- function( nloo = object$nobs, K = if (!inherits(object, "datafit")) 5 else 10, cvfits = object$cvfits, - search_control = list(), + search_control = NULL, validate_search = TRUE, seed = NA, search_terms = NULL, @@ -392,7 +392,11 @@ cv_varsel.refmodel <- function( validate_search, cvfits, args_search = nlist( - method, ndraws, nclusters, nterms_max, search_control, penalty, + method, ndraws, nclusters, nterms_max, + search_control = if ( + method == "forward" && is.null(search_control) + ) list(...) else search_control, + penalty, search_terms = if (search_terms_was_null) NULL else search_terms ), clust_used_search = refdist_info_search$clust_used, diff --git a/R/projfun.R b/R/projfun.R index 44a46bd17..d967eb9fb 100644 --- a/R/projfun.R +++ b/R/projfun.R @@ -1,10 +1,12 @@ # Function to project the reference model onto a single submodel with predictor # terms given in `predictor_terms`. Note that "single submodel" does not refer # to a single fit (there are as many fits for this single submodel as there are -# projected draws). At the end, init_submodl() is called, so the output is of -# class `submodl`. +# projected draws). The case `is.null(search_control)` occurs in two situations: +# (i) when called from search_forward() with `...` as the intended control +# arguments and (ii) when called from perf_eval(). At the end, init_submodl() is +# called, so the output is of class `submodl`. proj_to_submodl <- function(predictor_terms, p_ref, refmodel, - search_control = list(), ...) { + search_control = NULL, ...) { y_unqs_aug <- refmodel$family$cats if (refmodel$family$for_latent && !is.null(y_unqs_aug)) { y_unqs_aug <- NULL @@ -30,7 +32,7 @@ proj_to_submodl <- function(predictor_terms, p_ref, refmodel, weights = refmodel$wobs, projpred_var = p_ref$var, projpred_ws_aug = p_ref$mu) - if (length(search_control) > 0) { + if (!is.null(search_control)) { args_divmin <- c(args_divmin, search_control) } else { args_divmin <- c(args_divmin, list(...)) diff --git a/R/varsel.R b/R/varsel.R index ddaeac01b..f66603618 100644 --- a/R/varsel.R +++ b/R/varsel.R @@ -53,7 +53,9 @@ #' parameters) for the search. In case of forward search, these arguments are #' passed to the divergence minimizer (see argument `div_minimizer` of #' [init_refmodel()] as well as section "Draw-wise divergence minimizers" of -#' [projpred-package]). In case of L1 search, possible arguments are: +#' [projpred-package]). In case of forward search, `NULL` causes `...` to be +#' used not only for the performance evaluation, but also for the search. In +#' case of L1 search, possible arguments are: #' * `lambda_min_ratio`: Ratio between the smallest and largest lambda in the #' L1-penalized search (default: `1e-5`). This parameter essentially #' determines how long the search is carried out, i.e., how large submodels @@ -226,7 +228,7 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", nclusters_pred = NULL, refit_prj = !inherits(object, "datafit"), nterms_max = NULL, verbose = TRUE, - search_control = list(), + search_control = NULL, penalty = NULL, search_terms = NULL, search_out = NULL, seed = NA, ...) { if (exists(".Random.seed", envir = .GlobalEnv)) { @@ -421,7 +423,11 @@ varsel.refmodel <- function(object, d_test = NULL, method = "forward", cvfits = refmodel$cvfits, ### args_search = nlist( - method, ndraws, nclusters, nterms_max, search_control, penalty, + method, ndraws, nclusters, nterms_max, + search_control = if ( + method == "forward" && is.null(search_control) + ) list(...) else search_control, + penalty, search_terms = if (search_terms_was_null) NULL else search_terms ), clust_used_search = search_path$p_sel$clust_used, diff --git a/man/cv_varsel.Rd b/man/cv_varsel.Rd index 5bfa4d4fd..66cd77381 100644 --- a/man/cv_varsel.Rd +++ b/man/cv_varsel.Rd @@ -36,7 +36,7 @@ cv_varsel(object, ...) nloo = object$nobs, K = if (!inherits(object, "datafit")) 5 else 10, cvfits = object$cvfits, - search_control = list(), + search_control = NULL, validate_search = TRUE, seed = NA, search_terms = NULL, @@ -139,7 +139,9 @@ additional information during the computations.} parameters) for the search. In case of forward search, these arguments are passed to the divergence minimizer (see argument \code{div_minimizer} of \code{\link[=init_refmodel]{init_refmodel()}} as well as section "Draw-wise divergence minimizers" of -\link{projpred-package}). In case of L1 search, possible arguments are: +\link{projpred-package}). In case of forward search, \code{NULL} causes \code{...} to be +used not only for the performance evaluation, but also for the search. In +case of L1 search, possible arguments are: \itemize{ \item \code{lambda_min_ratio}: Ratio between the smallest and largest lambda in the L1-penalized search (default: \code{1e-5}). This parameter essentially diff --git a/man/varsel.Rd b/man/varsel.Rd index 71e0f0b74..5286f741a 100644 --- a/man/varsel.Rd +++ b/man/varsel.Rd @@ -24,7 +24,7 @@ varsel(object, ...) refit_prj = !inherits(object, "datafit"), nterms_max = NULL, verbose = TRUE, - search_control = list(), + search_control = NULL, penalty = NULL, search_terms = NULL, search_out = NULL, @@ -93,7 +93,9 @@ additional information during the computations.} parameters) for the search. In case of forward search, these arguments are passed to the divergence minimizer (see argument \code{div_minimizer} of \code{\link[=init_refmodel]{init_refmodel()}} as well as section "Draw-wise divergence minimizers" of -\link{projpred-package}). In case of L1 search, possible arguments are: +\link{projpred-package}). In case of forward search, \code{NULL} causes \code{...} to be +used not only for the performance evaluation, but also for the search. In +case of L1 search, possible arguments are: \itemize{ \item \code{lambda_min_ratio}: Ratio between the smallest and largest lambda in the L1-penalized search (default: \code{1e-5}). This parameter essentially diff --git a/tests/testthat/helpers/testers.R b/tests/testthat/helpers/testers.R index f95182f58..e777abb54 100644 --- a/tests/testthat/helpers/testers.R +++ b/tests/testthat/helpers/testers.R @@ -1965,6 +1965,7 @@ vsel_tester <- function( penalty_expected = NULL, search_terms_expected = NULL, search_trms_empty_size = FALSE, + search_control_expected = NULL, extra_tol = 1.1, info_str = "" ) { @@ -2398,6 +2399,10 @@ vsel_tester <- function( expect_identical(vs$cvfits, cvfits_expected, info = info_str) # args_search + sce <- search_control_expected[!sapply(search_control_expected, is.null)] + if (!length(sce)) { + sce <- if (method_expected == "forward") list() else NULL + } expect_equal( vs$args_search, list( @@ -2411,7 +2416,7 @@ vsel_tester <- function( NULL }, nterms_max = vs$nterms_max, - search_control = list(), + search_control = sce, penalty = penalty_expected, search_terms = if (is.null(search_terms_expected)) { NULL diff --git a/tests/testthat/test_varsel.R b/tests/testthat/test_varsel.R index 2fb42eed3..2e4218140 100644 --- a/tests/testthat/test_varsel.R +++ b/tests/testthat/test_varsel.R @@ -22,6 +22,7 @@ test_that(paste( search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -162,6 +163,7 @@ test_that(paste( search_trms_empty_size = length(args_vs_i$search_terms) && all(grepl("\\+", args_vs_i$search_terms)), + search_control_expected = args_vs_i[c("avoid.increase")], info_str = tstsetup ) expect_equal(vs_repr[setdiff(names(vs_repr), @@ -289,6 +291,7 @@ test_that(paste( search_trms_empty_size = length(args_vs_i$search_terms) && all(grepl("\\+", args_vs_i$search_terms)), + search_control_expected = args_vs_i[c("avoid.increase")], info_str = tstsetup ) @@ -560,6 +563,7 @@ test_that("`refit_prj` works", { search_trms_empty_size = length(args_vs_i$search_terms) && all(grepl("\\+", args_vs_i$search_terms)), + search_control_expected = args_vs_i[c("avoid.increase")], extra_tol = extra_tol_crr, info_str = tstsetup ) @@ -779,6 +783,8 @@ test_that(paste( search_trms_empty_size = length(args_vs_i$search_terms) && all(grepl("\\+", args_vs_i$search_terms)), + search_control_expected = c(args_vs_i[c("avoid.increase")], + list(regul = regul_tst[j])), info_str = tstsetup ) } @@ -1139,6 +1145,7 @@ test_that("varsel.vsel() works for `vsel` objects from varsel()", { search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], extra_tol = extra_tol_crr, info_str = tstsetup ) @@ -1195,6 +1202,7 @@ test_that("varsel.vsel() works for `vsel` objects from cv_varsel()", { search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -1233,6 +1241,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -1358,6 +1367,7 @@ test_that("`refit_prj` works", { search_trms_empty_size = length(args_cvvs_i$search_terms) && all(grepl("\\+", args_cvvs_i$search_terms)), + search_control_expected = args_cvvs_i[c("avoid.increase")], info_str = tstsetup ) } @@ -1498,6 +1508,7 @@ test_that("setting `nloo` smaller than the number of observations works", { search_trms_empty_size = length(args_cvvs_i$search_terms) && all(grepl("\\+", args_cvvs_i$search_terms)), + search_control_expected = args_cvvs_i[c("avoid.increase")], info_str = tstsetup ) # Expected equality for most elements with a few exceptions: @@ -1573,6 +1584,7 @@ test_that("`validate_search` works", { search_trms_empty_size = length(args_cvvs_i$search_terms) && all(grepl("\\+", args_cvvs_i$search_terms)), + search_control_expected = args_cvvs_i[c("avoid.increase")], info_str = tstsetup ) # Expected equality for most elements with a few exceptions: @@ -1863,6 +1875,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs_i$search_terms) && all(grepl("\\+", args_cvvs_i$search_terms)), + search_control_expected = args_cvvs_i[c("avoid.increase")], info_str = tstsetup ) # Expected equality for most elements with a few exceptions: @@ -1913,6 +1926,7 @@ test_that("`cvfun` included in the `refmodel` object works", { search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -1967,6 +1981,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -2017,6 +2032,7 @@ test_that(paste( search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -2076,6 +2092,7 @@ test_that(paste( search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -2143,6 +2160,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], extra_tol = extra_tol_crr, info_str = tstsetup ) @@ -2203,6 +2221,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -2265,6 +2284,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -2327,6 +2347,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -2388,6 +2409,7 @@ test_that(paste( search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -2465,6 +2487,7 @@ test_that(paste( search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -2559,6 +2582,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], extra_tol = extra_tol_crr, info_str = tstsetup ) @@ -2641,6 +2665,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } @@ -2697,6 +2722,7 @@ test_that("cv_varsel.vsel(): `nloo` works for `vsel` objects from varsel()", { search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) vsel_tester( @@ -2713,6 +2739,7 @@ test_that("cv_varsel.vsel(): `nloo` works for `vsel` objects from varsel()", { search_trms_empty_size = length(args_vs[[tstsetup]]$search_terms) && all(grepl("\\+", args_vs[[tstsetup]]$search_terms)), + search_control_expected = args_vs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) tstsetup_counter <- tstsetup_counter + 1L @@ -2800,6 +2827,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], extra_tol = extra_tol_crr, info_str = tstsetup ) @@ -2833,6 +2861,7 @@ test_that(paste( search_trms_empty_size = length(args_cvvs[[tstsetup]]$search_terms) && all(grepl("\\+", args_cvvs[[tstsetup]]$search_terms)), + search_control_expected = args_cvvs[[tstsetup]][c("avoid.increase")], info_str = tstsetup ) } From 8c4289783e6864c95d0b4948b1a5a65e4e9f4455 Mon Sep 17 00:00:00 2001 From: fweber144 Date: Mon, 20 Nov 2023 14:06:42 +0100 Subject: [PATCH 3/4] Add a `NEWS.md` entry for the new argument `search_control`. --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 836a22f81..3285b57ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -27,6 +27,7 @@ If you read this from a place other than Date: Mon, 20 Nov 2023 14:19:40 +0100 Subject: [PATCH 4/4] Deprecate the old L1 search tuning parameter arguments. --- NEWS.md | 2 +- R/cv_varsel.R | 23 +++++++++++++++++++++++ R/varsel.R | 39 ++++++++++++++++++++++++++++++++++++--- man/cv_varsel.Rd | 18 ++++++++++++++++++ man/varsel.Rd | 18 ++++++++++++++++++ 5 files changed, 96 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index 3285b57ab..6a9bdad47 100644 --- a/NEWS.md +++ b/NEWS.md @@ -27,7 +27,7 @@ If you read this from a place other than