diff --git a/DESCRIPTION b/DESCRIPTION index a236bb71..55a9bf52 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mikropml Title: User-Friendly R Package for Supervised Machine Learning Pipelines -Version: 1.6.0.9000 +Version: 1.6.0.9001 Date: 2023-04-14 Authors@R: c(person(given = "Begüm", diff --git a/R/checks.R b/R/checks.R index 9e27a122..d49d1b42 100644 --- a/R/checks.R +++ b/R/checks.R @@ -259,7 +259,7 @@ check_outcome_column <- function(dataset, outcome_colname, check_values = TRUE, #' \dontrun{ #' check_outcome_value(otu_small, "dx", "cancer") #' } -check_outcome_value <- function(dataset, outcome_colname) { +check_outcome_value <- function(dataset, outcome_colname, pos_outcome = NULL) { # check no NA's outcomes_vec <- dataset %>% dplyr::pull(outcome_colname) num_missing <- sum(is.na(outcomes_vec)) @@ -273,20 +273,17 @@ check_outcome_value <- function(dataset, outcome_colname) { warning(paste0("Possible missing data in the output variable: ", num_empty, " empty value(s).")) } - outcomes_all <- dataset %>% - dplyr::pull(outcome_colname) - # check if continuous outcome - isnum <- is.numeric(outcomes_all) + isnum <- is.numeric(outcomes_vec) if (isnum) { # check if it might actually be categorical - if (all(floor(outcomes_all) == outcomes_all)) { + if (all(floor(outcomes_vec) == outcomes_vec)) { warning("Data is being considered numeric, but all outcome values are integers. If you meant to code your values as categorical, please use character values.") } } # check binary and multiclass outcome - outcomes <- outcomes_all %>% + outcomes <- outcomes_vec %>% unique() num_outcomes <- length(outcomes) if (num_outcomes < 2) { @@ -299,6 +296,60 @@ check_outcome_value <- function(dataset, outcome_colname) { } } +#' Check or set outcome column to be a factor with `pos_class` as the first level +#' +#' @inheritParams run_ml +#' +#' @return dataset, with the outcome column as a factor +#' @keywords internal +#' @author Kelly Sovacool, \email{sovacool@@umich.edu} +#' +#' @examples +#' dat <- data.frame("dx" = c("a", "b", "a", "b", "b", "a"), feat = 1:6) +#' dat %>% set_outcome_factor("dx", "a") +#' dat %>% set_outcome_factor("dx", "b") +set_outcome_factor <- function(dataset, outcome_colname, pos_class) { + relevel_outcome <- FALSE + outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname) + # make sure it's either a factor or pos_class is set. + # the first factor level is used as the positive class by caret + if (!is.factor(outcomes_vctr)) { + if (is.null(pos_class)) { + stop(paste0( + "Either the outcome column `", outcome_colname, + "` must be a factor with the first factor level being the positive class,\n", + "or you must specify `pos_class`." + )) + } + relevel_outcome <- TRUE + } else { + first_lvl <- levels(outcomes_vctr)[1] + if (!is.null(pos_class) & pos_class != first_lvl) { + warning(paste0( + "`pos_class` is set, but it is not the first level in the outcome column. ", + "Releveling the outcome column to set ", + "`pos_class`=", pos_class, " as the first level." + )) + relevel_outcome <- TRUE + } + } + if (isTRUE(relevel_outcome)) { + if (!(pos_class %in% outcomes_vctr)) { + stop(paste0( + "pos_class `", pos_class, + "` not found in outcome column." + )) + } + dataset[outcome_colname] <- factor(outcomes_vctr, + levels = unique(c( + pos_class, + outcomes_vctr + )) + ) + } + return(dataset) +} + #' Check whether package(s) are installed #' #' @param ... names of packages to check diff --git a/R/performance.R b/R/performance.R index 15b47236..276db114 100644 --- a/R/performance.R +++ b/R/performance.R @@ -116,17 +116,28 @@ get_perf_metric_name <- function(outcome_type) { #' class_probs = TRUE #' ) #' } -calc_perf_metrics <- function(test_data, trained_model, outcome_colname, perf_metric_function, class_probs) { +calc_perf_metrics <- function(test_data, trained_model, outcome_colname, + perf_metric_function, class_probs, + pos_class = NULL) { pred_type <- "raw" if (class_probs) pred_type <- "prob" preds <- stats::predict(trained_model, test_data, type = pred_type) + obs <- test_data %>% dplyr::pull(outcome_colname) if (class_probs) { - uniq_obs <- unique(c(test_data %>% dplyr::pull(outcome_colname), as.character(trained_model$pred$obs))) - obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs) + if (is.factor(obs)) { + uniq_obs <- obs %>% levels() + } else { + uniq_obs <- unique(c( + pos_class, + test_data %>% dplyr::pull(outcome_colname), + as.character(trained_model$pred$obs) + )) + obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs) + } + # TODO refactor this line pred_class <- factor(names(preds)[apply(preds, 1, which.max)], levels = uniq_obs) perf_met <- perf_metric_function(data.frame(obs = obs, pred = pred_class, preds), lev = uniq_obs) } else { - obs <- test_data %>% dplyr::pull(outcome_colname) perf_met <- perf_metric_function(data.frame(obs = obs, pred = preds)) } return(perf_met) diff --git a/R/run_ml.R b/R/run_ml.R index 7abe8b82..74503b52 100644 --- a/R/run_ml.R +++ b/R/run_ml.R @@ -19,6 +19,10 @@ #' - xgbTree: xgboost #' @param outcome_colname Column name as a string of the outcome variable #' (default `NULL`; the first column will be chosen automatically). +#' @param pos_class The positive class, i.e. which level of `outcome_colname` is +#' the event of interest. If the outcome is binary, either the +#' `outcome_colname` must be a factor with the first level being the positive +#' class, or `pos_class` must be set. (default: `NULL`). #' @param hyperparameters Dataframe of hyperparameters #' (default `NULL`; sensible defaults will be chosen automatically). #' @param seed Random seed (default: `NA`). @@ -131,6 +135,7 @@ run_ml <- function(dataset, method, outcome_colname = NULL, + pos_class = NULL, hyperparameters = NULL, find_feature_importance = FALSE, calculate_performance = TRUE, @@ -216,6 +221,10 @@ run_ml <- outcome_type <- get_outcome_type(outcomes_vctr) class_probs <- outcome_type != "continuous" + if (outcome_type == "binary") { + # enforce factor levels + dataset <- dataset %>% set_outcome_factor(outcome_colname, pos_class) + } if (is.null(perf_metric_function)) { perf_metric_function <- get_perf_metric_fn(outcome_type) @@ -254,6 +263,8 @@ run_ml <- if (!is.na(seed)) { set.seed(seed) } + # verify that correct outcome level got used + trained_model_caret$levels[1] if (calculate_performance) { performance_tbl <- get_performance_tbl( diff --git a/data-raw/otu_mini_bin.R b/data-raw/otu_mini_bin.R index 0abf7e1b..3d8651bf 100644 --- a/data-raw/otu_mini_bin.R +++ b/data-raw/otu_mini_bin.R @@ -35,6 +35,7 @@ otu_mini_group <- c( otu_mini_bin_results_glmnet <- mikropml::run_ml(otu_mini_bin, # use built-in hyperparams "glmnet", outcome_colname = "dx", + pos_class = "cancer", find_feature_importance = FALSE, seed = 2019, cv_times = 2 @@ -77,6 +78,7 @@ use_data(otu_mini_cv, overwrite = TRUE) otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin, "rf", outcome_colname = "dx", + pos_class = "cancer", find_feature_importance = TRUE, seed = 2019, cv_times = 2, @@ -87,6 +89,7 @@ use_data(otu_mini_bin_results_rf, overwrite = TRUE) otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin, "svmRadial", outcome_colname = "dx", + pos_class = "cancer", find_feature_importance = FALSE, seed = 2019, cv_times = 2 @@ -96,6 +99,7 @@ use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE) otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin, "xgbTree", outcome_colname = "dx", + pos_class = "cancer", find_feature_importance = FALSE, seed = 2019, cv_times = 2 @@ -105,6 +109,7 @@ use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE) otu_mini_bin_results_rpart2 <- mikropml::run_ml(otu_mini_bin, "rpart2", outcome_colname = "dx", + pos_class = "cancer", find_feature_importance = FALSE, seed = 2019, cv_times = 2 diff --git a/data/otu_mini_bin.rda b/data/otu_mini_bin.rda index 33587bc0..0dacc6b2 100644 Binary files a/data/otu_mini_bin.rda and b/data/otu_mini_bin.rda differ diff --git a/data/otu_mini_bin_results_glmnet.rda b/data/otu_mini_bin_results_glmnet.rda index 41359778..e2931290 100644 Binary files a/data/otu_mini_bin_results_glmnet.rda and b/data/otu_mini_bin_results_glmnet.rda differ diff --git a/data/otu_mini_bin_results_rf.rda b/data/otu_mini_bin_results_rf.rda index 5f0cf9e4..804fad8f 100644 Binary files a/data/otu_mini_bin_results_rf.rda and b/data/otu_mini_bin_results_rf.rda differ diff --git a/data/otu_mini_bin_results_rpart2.rda b/data/otu_mini_bin_results_rpart2.rda index 3b560b1c..3de3588a 100644 Binary files a/data/otu_mini_bin_results_rpart2.rda and b/data/otu_mini_bin_results_rpart2.rda differ diff --git a/data/otu_mini_bin_results_svmRadial.rda b/data/otu_mini_bin_results_svmRadial.rda index 450a4d94..8d6f6190 100644 Binary files a/data/otu_mini_bin_results_svmRadial.rda and b/data/otu_mini_bin_results_svmRadial.rda differ diff --git a/data/otu_mini_bin_results_xgbTree.rda b/data/otu_mini_bin_results_xgbTree.rda index 5ccaa6cb..5a7ab666 100644 Binary files a/data/otu_mini_bin_results_xgbTree.rda and b/data/otu_mini_bin_results_xgbTree.rda differ diff --git a/data/otu_mini_cv.rda b/data/otu_mini_cv.rda index 5e2d9cf4..9da4b441 100644 Binary files a/data/otu_mini_cv.rda and b/data/otu_mini_cv.rda differ diff --git a/docs/dev/CODE_OF_CONDUCT.html b/docs/dev/CODE_OF_CONDUCT.html index 4fc94a4d..ccd6e588 100644 --- a/docs/dev/CODE_OF_CONDUCT.html +++ b/docs/dev/CODE_OF_CONDUCT.html @@ -10,7 +10,7 @@ mikropml - 1.5.0.9000 + 1.6.0.9001