Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve factor levels in outcome column #337

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mikropml
Title: User-Friendly R Package for Supervised Machine Learning Pipelines
Version: 1.6.0.9000
Version: 1.6.0.9001
Date: 2023-04-14
Authors@R:
c(person(given = "Begüm",
Expand Down
65 changes: 58 additions & 7 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ check_outcome_column <- function(dataset, outcome_colname, check_values = TRUE,
#' \dontrun{
#' check_outcome_value(otu_small, "dx", "cancer")
#' }
check_outcome_value <- function(dataset, outcome_colname) {
check_outcome_value <- function(dataset, outcome_colname, pos_outcome = NULL) {
# check no NA's
outcomes_vec <- dataset %>% dplyr::pull(outcome_colname)
num_missing <- sum(is.na(outcomes_vec))
Expand All @@ -273,20 +273,17 @@ check_outcome_value <- function(dataset, outcome_colname) {
warning(paste0("Possible missing data in the output variable: ", num_empty, " empty value(s)."))
}

outcomes_all <- dataset %>%
dplyr::pull(outcome_colname)

# check if continuous outcome
isnum <- is.numeric(outcomes_all)
isnum <- is.numeric(outcomes_vec)
if (isnum) {
# check if it might actually be categorical
if (all(floor(outcomes_all) == outcomes_all)) {
if (all(floor(outcomes_vec) == outcomes_vec)) {
warning("Data is being considered numeric, but all outcome values are integers. If you meant to code your values as categorical, please use character values.")
}
}

# check binary and multiclass outcome
outcomes <- outcomes_all %>%
outcomes <- outcomes_vec %>%
unique()
num_outcomes <- length(outcomes)
if (num_outcomes < 2) {
Expand All @@ -299,6 +296,60 @@ check_outcome_value <- function(dataset, outcome_colname) {
}
}

#' Check or set outcome column to be a factor with `pos_class` as the first level
#'
#' If the outcome column is not already a factor, `pos_class` is required and
#' the column is converted to a factor with `pos_class` as its first level.
#' If the outcome column is already a factor, it is returned unchanged unless
#' `pos_class` is set and differs from the first level, in which case the
#' column is releveled (with a warning) so that `pos_class` comes first.
#'
#' @inheritParams run_ml
#'
#' @return dataset, with the outcome column as a factor
#' @keywords internal
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' dat <- data.frame("dx" = c("a", "b", "a", "b", "b", "a"), feat = 1:6)
#' dat %>% set_outcome_factor("dx", "a")
#' dat %>% set_outcome_factor("dx", "b")
#' }
set_outcome_factor <- function(dataset, outcome_colname, pos_class = NULL) {
  relevel_outcome <- FALSE
  # `[[` looks the column up by its string name
  outcomes_vctr <- dataset[[outcome_colname]]
  # make sure it's either a factor or pos_class is set.
  # the first factor level is used as the positive class by caret
  if (!is.factor(outcomes_vctr)) {
    if (is.null(pos_class)) {
      stop(paste0(
        "Either the outcome column `", outcome_colname,
        "` must be a factor with the first factor level being the positive class,\n",
        "or you must specify `pos_class`."
      ))
    }
    relevel_outcome <- TRUE
  } else {
    first_lvl <- levels(outcomes_vctr)[1]
    # `&&` short-circuits: with `&`, a NULL `pos_class` makes the condition
    # logical(0) and `if` fails with "argument is of length zero"
    if (!is.null(pos_class) && pos_class != first_lvl) {
      warning(paste0(
        "`pos_class` is set, but it is not the first level in the outcome column. ",
        "Releveling the outcome column to set ",
        "`pos_class`=", pos_class, " as the first level."
      ))
      relevel_outcome <- TRUE
    }
  }
  if (isTRUE(relevel_outcome)) {
    if (!(pos_class %in% outcomes_vctr)) {
      stop(paste0(
        "pos_class `", pos_class,
        "` not found in outcome column."
      ))
    }
    # c(<character>, <factor>) would reduce the factor to its integer codes,
    # yielding bogus levels ("1", "2", ...) and NA outcomes, so take the
    # labels from levels() when the column is already a factor
    if (is.factor(outcomes_vctr)) {
      other_lvls <- levels(outcomes_vctr)
    } else {
      other_lvls <- outcomes_vctr
    }
    dataset[outcome_colname] <- factor(outcomes_vctr,
      levels = unique(c(pos_class, other_lvls))
    )
  }
  return(dataset)
}

#' Check whether package(s) are installed
#'
#' @param ... names of packages to check
Expand Down
19 changes: 15 additions & 4 deletions R/performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,28 @@ get_perf_metric_name <- function(outcome_type) {
#' class_probs = TRUE
#' )
#' }
calc_perf_metrics <- function(test_data, trained_model, outcome_colname, perf_metric_function, class_probs) {
calc_perf_metrics <- function(test_data, trained_model, outcome_colname,
perf_metric_function, class_probs,
pos_class = NULL) {
pred_type <- "raw"
if (class_probs) pred_type <- "prob"
preds <- stats::predict(trained_model, test_data, type = pred_type)
obs <- test_data %>% dplyr::pull(outcome_colname)
if (class_probs) {
uniq_obs <- unique(c(test_data %>% dplyr::pull(outcome_colname), as.character(trained_model$pred$obs)))
obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs)
if (is.factor(obs)) {
uniq_obs <- obs %>% levels()
} else {
uniq_obs <- unique(c(
pos_class,
test_data %>% dplyr::pull(outcome_colname),
as.character(trained_model$pred$obs)
))
obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs)
}
# TODO refactor this line
pred_class <- factor(names(preds)[apply(preds, 1, which.max)], levels = uniq_obs)
perf_met <- perf_metric_function(data.frame(obs = obs, pred = pred_class, preds), lev = uniq_obs)
} else {
obs <- test_data %>% dplyr::pull(outcome_colname)
perf_met <- perf_metric_function(data.frame(obs = obs, pred = preds))
}
return(perf_met)
Expand Down
11 changes: 11 additions & 0 deletions R/run_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#' - xgbTree: xgboost
#' @param outcome_colname Column name as a string of the outcome variable
#' (default `NULL`; the first column will be chosen automatically).
#' @param pos_class The positive class, i.e. which level of the outcome column
#' is the event of interest. If the outcome is binary, either the column named
#' by `outcome_colname` must be a factor whose first level is the positive
#' class, or `pos_class` must be set (default: `NULL`).
#' @param hyperparameters Dataframe of hyperparameters
#' (default `NULL`; sensible defaults will be chosen automatically).
#' @param seed Random seed (default: `NA`).
Expand Down Expand Up @@ -131,6 +135,7 @@ run_ml <-
function(dataset,
method,
outcome_colname = NULL,
pos_class = NULL,
hyperparameters = NULL,
find_feature_importance = FALSE,
calculate_performance = TRUE,
Expand Down Expand Up @@ -216,6 +221,10 @@ run_ml <-

outcome_type <- get_outcome_type(outcomes_vctr)
class_probs <- outcome_type != "continuous"
if (outcome_type == "binary") {
# enforce factor levels
dataset <- dataset %>% set_outcome_factor(outcome_colname, pos_class)
}

if (is.null(perf_metric_function)) {
perf_metric_function <- get_perf_metric_fn(outcome_type)
Expand Down Expand Up @@ -254,6 +263,8 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}
# verify that correct outcome level got used
trained_model_caret$levels[1]

if (calculate_performance) {
performance_tbl <- get_performance_tbl(
Expand Down
5 changes: 5 additions & 0 deletions data-raw/otu_mini_bin.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ otu_mini_group <- c(
otu_mini_bin_results_glmnet <- mikropml::run_ml(otu_mini_bin, # use built-in hyperparams
"glmnet",
outcome_colname = "dx",
pos_class = "cancer",
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down Expand Up @@ -77,6 +78,7 @@ use_data(otu_mini_cv, overwrite = TRUE)
otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
"rf",
outcome_colname = "dx",
pos_class = "cancer",
find_feature_importance = TRUE,
seed = 2019,
cv_times = 2,
Expand All @@ -87,6 +89,7 @@ use_data(otu_mini_bin_results_rf, overwrite = TRUE)
otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
"svmRadial",
outcome_colname = "dx",
pos_class = "cancer",
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -96,6 +99,7 @@ use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)
otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
"xgbTree",
outcome_colname = "dx",
pos_class = "cancer",
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -105,6 +109,7 @@ use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)
otu_mini_bin_results_rpart2 <- mikropml::run_ml(otu_mini_bin,
"rpart2",
outcome_colname = "dx",
pos_class = "cancer",
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down
Binary file modified data/otu_mini_bin.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rf.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rpart2.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_svmRadial.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_xgbTree.rda
Binary file not shown.
Binary file modified data/otu_mini_cv.rda
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/dev/CODE_OF_CONDUCT.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/CONTRIBUTING.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/LICENSE-text.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/LICENSE.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/SUPPORT.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/articles/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/articles/introduction.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/articles/paper.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 18 additions & 18 deletions docs/dev/articles/parallel.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/articles/preprocess.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/dev/articles/tuning.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading