diff --git a/NEWS.md b/NEWS.md index 5666c81b..48aa5717 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # development version +- New option to impute missing data after the train/test split rather than before (#301, @megancoden and @shah-priyal). + - Added `impute_after_split` option to `run_ml()`, which defaults to FALSE. + - Added `impute_in_preprocessing` option to `preprocess_data()`, which defaults to TRUE. + # mikropml 1.6.0 - New functions: @@ -13,7 +17,6 @@ - Renamed the column `names` to `feat` to represent each feature or group of correlated features. - New column `lower` and `upper` to report the bounds of the empirical 95% confidence interval from the permutation test. See `vignette('parallel')` for an example of plotting feature importance with confidence intervals. -- Minor documentation improvements (#323, #332, @kelly-sovacool). # mikropml 1.5.0 diff --git a/R/impute.R b/R/impute.R new file mode 100644 index 00000000..149e5316 --- /dev/null +++ b/R/impute.R @@ -0,0 +1,37 @@ +#' Replace NA values with the median value of the column for continuous variables in the dataset +#' +#' @param transformed_cont Data frame that may include NA values in one or more columns +#' +#' @return Data frame that has no NA values in continuous numeric columns +#' +#' @examples +#' transformed_cont <- impute(transformed_cont) +#' train_data <- impute(train_data) +#' test_data <- impute(test_data) +impute <- function(transformed_cont) { + sapply_fn <- select_apply("sapply") + cl <- sapply_fn(transformed_cont, function(x) { + class(x) + }) + missing <- + is.na(transformed_cont[, cl %in% c("integer", "numeric")]) + n_missing <- sum(missing) + if (n_missing > 0) { + transformed_cont <- sapply_fn(transformed_cont, function(x) { + if (class(x) %in% c("integer", "numeric")) { + m <- is.na(x) + x[m] <- stats::median(x, na.rm = TRUE) + } + message(typeof(x)) + message(class(x)) + return(x) + }) %>% dplyr::as_tibble() + message( + paste0( + n_missing, + " missing continuous value(s) 
were imputed using the median value of the feature." + ) + ) + } + return (transformed_cont) +} \ No newline at end of file diff --git a/R/preprocess.R b/R/preprocess.R index 3bddd927..20495aaf 100644 --- a/R/preprocess.R +++ b/R/preprocess.R @@ -60,7 +60,7 @@ preprocess_data <- function(dataset, outcome_colname, method = c("center", "scale"), remove_var = "nzv", collapse_corr_feats = TRUE, to_numeric = TRUE, group_neg_corr = TRUE, - prefilter_threshold = 1) { + prefilter_threshold = 1, impute_in_preprocessing = TRUE) { progbar <- NULL if (isTRUE(check_packages_installed("progressr"))) { progbar <- progressr::progressor(steps = 20, message = "preprocessing") @@ -70,10 +70,11 @@ preprocess_data <- function(dataset, outcome_colname, check_outcome_column(dataset, outcome_colname, check_values = FALSE) check_remove_var(remove_var) pbtick(progbar) + dataset[[outcome_colname]] <- replace_spaces(dataset[[outcome_colname]]) dataset <- rm_missing_outcome(dataset, outcome_colname) split_dat <- split_outcome_features(dataset, outcome_colname) - + features <- split_dat$features removed_feats <- character(0) if (to_numeric) { @@ -83,14 +84,14 @@ preprocess_data <- function(dataset, outcome_colname, features <- feats$dat } pbtick(progbar) - + nv_feats <- process_novar_feats(features, progbar = progbar) pbtick(progbar) split_feats <- process_cat_feats(nv_feats$var_feats, progbar = progbar) pbtick(progbar) - cont_feats <- process_cont_feats(split_feats$cont_feats, method) + cont_feats <- process_cont_feats(split_feats$cont_feats, method, impute_in_preprocessing) pbtick(progbar) - + # combine all processed features processed_feats <- dplyr::bind_cols( cont_feats$transformed_cont, @@ -98,7 +99,7 @@ preprocess_data <- function(dataset, outcome_colname, nv_feats$novar_feats ) pbtick(progbar) - + # remove features with (near-)zero variance feats <- get_caret_processed_df(processed_feats, remove_var) processed_feats <- feats$processed @@ -140,17 +141,15 @@ preprocess_data <- 
function(dataset, outcome_colname, #' @inheritParams run_ml #' #' @return dataset with no missing outcomes -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' rm_missing_outcome(mikropml::otu_mini_bin, "dx") #' #' test_df <- mikropml::otu_mini_bin #' test_df[1:100, "dx"] <- NA #' rm_missing_outcome(test_df, "dx") -#' } rm_missing_outcome <- function(dataset, outcome_colname) { n_outcome_na <- sum(is.na(dataset %>% dplyr::pull(outcome_colname))) total_outcomes <- nrow(dataset) @@ -168,13 +167,11 @@ rm_missing_outcome <- function(dataset, outcome_colname) { #' @param features dataframe of features for machine learning #' #' @return dataframe with numeric columns where possible -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' class(change_to_num(data.frame(val = c("1", "2", "3")))[[1]]) -#' } change_to_num <- function(features) { lapply_fn <- select_apply(fun = "lapply") check_features(features, check_missing = FALSE) @@ -228,13 +225,11 @@ remove_singleton_columns <- function(dat, threshold = 1) { #' @param progbar optional progress bar (default: `NULL`) #' #' @return list of two dataframes: features with variability (unprocessed) and without (processed) -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' process_novar_feats(mikropml::otu_small[, 2:ncol(otu_small)]) -#' } process_novar_feats <- function(features, progbar = NULL) { novar_feats <- NULL var_feats <- NULL @@ -297,13 +292,11 @@ process_novar_feats <- function(features, progbar = NULL) { #' @inheritParams process_novar_feats #' #' @return list of two dataframes: categorical (processed) and continuous features (unprocessed) -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' process_cat_feats(mikropml::otu_small[, 2:ncol(otu_small)]) -#' } 
process_cat_feats <- function(features, progbar = NULL) { feature_design_cat_mat <- NULL cont_feats <- NULL @@ -367,14 +360,12 @@ process_cat_feats <- function(features, progbar = NULL) { #' @inheritParams get_caret_processed_df #' #' @return dataframe of preprocessed features -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' process_cont_feats(mikropml::otu_small[, 2:ncol(otu_small)], c("center", "scale")) -#' } -process_cont_feats <- function(features, method) { +process_cont_feats <- function(features, method, impute_in_preprocessing) { transformed_cont <- NULL removed_cont <- NULL @@ -389,31 +380,12 @@ process_cont_feats <- function(features, method) { transformed_cont <- feats$processed removed_cont <- feats$removed } - sapply_fn <- select_apply("sapply") - cl <- sapply_fn(transformed_cont, function(x) { - class(x) - }) - missing <- - is.na(transformed_cont[, cl %in% c("integer", "numeric")]) - n_missing <- sum(missing) - if (n_missing > 0) { # impute missing data using the median value - transformed_cont <- sapply_fn(transformed_cont, function(x) { - if (class(x) %in% c("integer", "numeric")) { - m <- is.na(x) - x[m] <- stats::median(x, na.rm = TRUE) - } - return(x) - }) %>% dplyr::as_tibble() - message( - paste0( - n_missing, - " missing continuous value(s) were imputed using the median value of the feature." 
- ) - ) + if (impute_in_preprocessing) { + transformed_cont <- impute(transformed_cont) + } } } - } return(list(transformed_cont = transformed_cont, removed_cont = removed_cont)) } @@ -450,11 +422,10 @@ get_caret_processed_df <- function(features, method) { #' @inheritParams process_novar_feats #' @param full_rank whether matrix should be full rank or not (see `[caret::dummyVars]) #' @return design matrix -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' df <- data.frame( #' outcome = c("normal", "normal", "cancer"), #' var1 = 1:3, @@ -463,7 +434,6 @@ get_caret_processed_df <- function(features, method) { #' var4 = c(0, 1, 0) #' ) #' get_caret_dummyvars_df(df, TRUE) -#' } get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL) { check_features(features, check_missing = FALSE) if (!is.null(process_novar_feats(features, progbar = progbar)$novar_feats)) { @@ -481,13 +451,11 @@ get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL) #' @inheritParams group_correlated_features #' #' @return features where perfectly correlated ones are collapsed -#' @keywords internal +#' @noRd #' @author Zena Lapp, \email{zenalapp@@umich.edu} #' #' @examples -#' \dontrun{ #' collapse_correlated_features(mikropml::otu_small[, 2:ncol(otu_small)]) -#' } collapse_correlated_features <- function(features, group_neg_corr = TRUE, progbar = NULL) { feats_nocorr <- features grp_feats <- NULL diff --git a/R/run_ml.R b/R/run_ml.R index 7abe8b82..4526b199 100644 --- a/R/run_ml.R +++ b/R/run_ml.R @@ -144,6 +144,7 @@ run_ml <- group_partitions = NULL, corr_thresh = 1, seed = NA, + impute_after_split = FALSE, ...) { check_all( dataset, @@ -162,7 +163,7 @@ run_ml <- if (!is.na(seed)) { set.seed(seed) } - + # `future.apply` is required for `find_feature_importance()`. # check it here to adhere to the fail fast principle. 
if (find_feature_importance) { @@ -173,20 +174,20 @@ run_ml <- if (find_feature_importance) { check_cat_feats(dataset %>% dplyr::select(-outcome_colname)) } - + dataset <- dataset %>% randomize_feature_order(outcome_colname) %>% # convert tibble to dataframe to silence warning from caret::train(): # "Warning: Setting row names on a tibble is deprecated.." as.data.frame() - + outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname) - + if (length(training_frac) == 1) { training_inds <- get_partition_indices(outcomes_vctr, - training_frac = training_frac, - groups = groups, - group_partitions = group_partitions + training_frac = training_frac, + groups = groups, + group_partitions = group_partitions ) } else { training_inds <- training_frac @@ -201,30 +202,34 @@ run_ml <- } check_training_frac(training_frac) check_training_indices(training_inds, dataset) - + train_data <- dataset[training_inds, ] test_data <- dataset[-training_inds, ] + if (impute_after_split == TRUE) { + train_data <- impute(train_data) + test_data <- impute(test_data) + } # train_groups & test_groups will be NULL if groups is NULL train_groups <- groups[training_inds] test_groups <- groups[-training_inds] - + if (is.null(hyperparameters)) { hyperparameters <- get_hyperparams_list(dataset, method) } tune_grid <- get_tuning_grid(hyperparameters, method) - - + + outcome_type <- get_outcome_type(outcomes_vctr) class_probs <- outcome_type != "continuous" - + if (is.null(perf_metric_function)) { perf_metric_function <- get_perf_metric_fn(outcome_type) } - + if (is.null(perf_metric_name)) { perf_metric_name <- get_perf_metric_name(outcome_type) } - + if (is.null(cross_val)) { cross_val <- define_cv( train_data, @@ -238,8 +243,8 @@ run_ml <- group_partitions = group_partitions ) } - - + + message("Training the model...") trained_model_caret <- train_model( train_data = train_data, @@ -254,7 +259,7 @@ run_ml <- if (!is.na(seed)) { set.seed(seed) } - + if (calculate_performance) { performance_tbl <- 
get_performance_tbl( trained_model_caret, @@ -269,7 +274,7 @@ run_ml <- } else { performance_tbl <- "Skipped calculating performance" } - + if (find_feature_importance) { message("Finding feature importance...") feature_importance_tbl <- get_feature_importance( @@ -287,7 +292,7 @@ run_ml <- } else { feature_importance_tbl <- "Skipped feature importance" } - + return( list( trained_model = trained_model_caret, diff --git a/tests/testthat/test-preprocess.R b/tests/testthat/test-preprocess.R index ef1555bb..2a95e3e3 100644 --- a/tests/testthat/test-preprocess.R +++ b/tests/testthat/test-preprocess.R @@ -442,14 +442,14 @@ test_that("process_cat_feats works", { test_that("process_cont_feats works", { expect_equal( - process_cont_feats(dplyr::as_tibble(test_df[1:3, 2]), method = c("center", "scale")), + process_cont_feats(dplyr::as_tibble(test_df[1:3, 2]), method = c("center", "scale"), impute_in_preprocessing = TRUE), list(transformed_cont = structure(list(value = c(-1, 0, 1)), row.names = c( NA, -3L ), class = c("tbl_df", "tbl", "data.frame")), removed_cont = character(0)) ) %>% suppressMessages() expect_message(expect_equal( - process_cont_feats(test_df[1:3, c(2, 9)], method = c("center", "scale")), + process_cont_feats(test_df[1:3, c(2, 9)], method = c("center", "scale"), impute_in_preprocessing = TRUE), list(transformed_cont = structure(list(var1 = c(-1, 0, 1), var8 = c( -0.707106781186547, 0.707106781186547, 0 @@ -671,3 +671,297 @@ test_that("preprocess_data replaces spaces in outcome column values (class label dat_proc ) %>% suppressMessages() }) + + + +test_that("setting impute param to false doesn't impute data", { + expect_message( + expect_equal( + preprocess_data(test_df, "outcome", + prefilter_threshold = -1, impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + -1, + 0, 1 + ), grp4 = c(0, 0, 1), var8 = c( + -0.707106781186547, 
0.707106781186547, + NA + )), row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list(grp1 = c( + "var10_0", "var2_b", "var3_yes", + "var4_1", "var9_x" + ), grp2 = c("var10_1", "var2_a"), grp3 = c( + "var1", + "var12" + ), grp4 = c("var2_c", "var7_1", "var9_y"), var8 = "var8"), + removed_feats = c("var5", "var6", "var11") + ) + ), + "Removed " + ) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df, "outcome", + prefilter_threshold = -1, + group_neg_corr = FALSE, + impute_in_preprocessing = FALSE + ), + list(dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + -1, + 0, 1 + ), grp4 = c(0, 0, 1), var7_1 = c(1, 1, 0), var8 = c( + -0.707106781186547, + 0.707106781186547, NA + )), row.names = c(NA, -3L), class = c( + "tbl_df", + "tbl", "data.frame" + )), grp_feats = list( + grp1 = c( + "var10_0", "var2_b", + "var3_yes", "var4_1", "var9_x" + ), grp2 = c("var10_1", "var2_a"), grp3 = c("var1", "var12"), grp4 = c("var2_c", "var9_y"), var7_1 = "var7_1", + var8 = "var8" + ), removed_feats = c("var5", "var6", "var11")) + )) %>% suppressMessages() + expect_equal( + preprocess_data(test_df[1:3, c("outcome", "var1")], "outcome"), + list( + dat_transformed = dplyr::tibble( + outcome = c("normal", "normal", "cancer"), + var1 = c(-1, 0, 1) + ), + grp_feats = NULL, + removed_feats = character(0) + ) + ) %>% suppressMessages() + expect_equal( + preprocess_data(test_df[1:3, c("outcome", "var2")], "outcome", impute_in_preprocessing = FALSE), + list( + dat_transformed = dplyr::tibble( + outcome = c("normal", "normal", "cancer"), + var2_a = c(1, 0, 0), + var2_b = c(0, 1, 0), + var2_c = c(0, 0, 1), + ), + grp_feats = NULL, + removed_feats = character(0) + ) + ) %>% suppressMessages() + expect_equal( + preprocess_data(test_df[1:3, c("outcome", "var3")], "outcome", impute_in_preprocessing = FALSE), + list( + dat_transformed = dplyr::tibble( + outcome = 
c("normal", "normal", "cancer"), + var3_yes = c(0, 1, 0), + ), + grp_feats = NULL, + removed_feats = character(0) + ) + ) %>% suppressMessages() + expect_equal( + preprocess_data(test_df[1:3, c("outcome", "var4")], "outcome", + prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = dplyr::tibble( + outcome = c("normal", "normal", "cancer"), + var4_1 = c(0, 1, 0), + ), + grp_feats = NULL, + removed_feats = character(0) + ) + ) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df[1:3, ], "outcome", + method = NULL, + prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + 1, + 2, 3 + ), grp4 = c(0, 0, 1), var8 = c(5, 6, NA)), row.names = c( + NA, + -3L + ), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list( + grp1 = c("var10_0", "var2_b", "var3_yes", "var4_1", "var9_x"), grp2 = c("var10_1", "var2_a"), grp3 = c("var1", "var12"), grp4 = c("var2_c", "var7_1", "var9_y"), var8 = "var8" + ), + removed_feats = c("var5", "var6", "var11") + ) + )) %>% suppressMessages() + expect_error( + preprocess_data(test_df[1:3, c("outcome", "var5")], "outcome", impute_in_preprocessing = FALSE), + "All features have zero variance" + ) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df[1:3, ], + "outcome", + method = c("range"), + prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + 0, + 0.5, 1 + ), grp4 = c(0, 0, 1), var8 = c(0, 1, NA)), row.names = c( + NA, + -3L + ), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list( + grp1 = c("var10_0", "var2_b", "var3_yes", "var4_1", "var9_x"), grp2 = c("var10_1", "var2_a"), grp3 = c("var1", "var12"), grp4 = c("var2_c", "var7_1", 
"var9_y"), var8 = "var8" + ), + removed_feats = c("var5", "var6", "var11") + ) + )) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df[1:3, ], + "outcome", + remove_var = "zv", + prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + -1, + 0, 1 + ), grp4 = c(0, 0, 1), var8 = c( + -0.707106781186547, 0.707106781186547, + NA + )), row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list(grp1 = c( + "var10_0", "var2_b", "var3_yes", + "var4_1", "var9_x" + ), grp2 = c("var10_1", "var2_a"), grp3 = c( + "var1", + "var12" + ), grp4 = c("var2_c", "var7_1", "var9_y"), var8 = "var8"), + removed_feats = c("var5", "var6", "var11") + ) + )) %>% suppressMessages() + expect_message( + expect_equal( + preprocess_data(test_df[1:3, ], "outcome", + remove_var = NULL, prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), grp1 = c(0, 1, 0), grp2 = c(1, 0, 0), grp3 = c( + -1, + 0, 1 + ), grp4 = c(0, 0, 1), var8 = c( + -0.707106781186547, 0.707106781186547, + NA + )), row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list(grp1 = c( + "var10_0", "var2_b", "var3_yes", + "var4_1", "var9_x" + ), grp2 = c("var10_1", "var2_a"), grp3 = c( + "var1", + "var12" + ), grp4 = c("var2_c", "var7_1", "var9_y"), var8 = "var8"), + removed_feats = c("var5", "var6", "var11") + ) + ), + "Removing" + ) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df[1:3, ], + "outcome", + remove_var = NULL, + collapse_corr_feats = FALSE, + prefilter_threshold = -1, + impute_in_preprocessing = FALSE + ), + list( + dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), var1 = c(-1, 0, 1), var8 = c( + -0.707106781186547, 
0.707106781186547, + NA + ), var12 = c(-1, 0, 1), var3_yes = c(0, 1, 0), var4_1 = c( + 0, + 1, 0 + ), var7_1 = c(1, 1, 0), var2_a = c(1, 0, 0), var2_b = c( + 0, + 1, 0 + ), var2_c = c(0, 0, 1), var9_x = c(0, 1, 0), var9_y = c( + 0, + 0, 1 + ), var10_0 = c(0, 1, 0), var10_1 = c(1, 0, 0), var5 = c( + 0, + 0, 0 + ), var6 = c(0, 0, 0), var11 = c(1, 1, 1)), row.names = c( + NA, + -3L + ), class = c("tbl_df", "tbl", "data.frame")), grp_feats = NULL, + removed_feats = character(0) + ) + )) %>% suppressMessages() + expect_error(preprocess_data(test_df[1:3, ], + "outcome", + method = c("asdf") + )) %>% suppressMessages() + expect_message(expect_equal( + preprocess_data(test_df, + "outcome", + to_numeric = FALSE, + impute_in_preprocessing = FALSE + ), + list(dat_transformed = structure(list(outcome = c( + "normal", "normal", + "cancer" + ), var1 = c(-1, 0, 1), grp1 = c(0, 1, 0), grp2 = c( + 1, + 0, 0 + ), grp3 = c(0, 0, 1), var8 = c( + -0.707106781186547, 0.707106781186547, + NA + )), row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame")), grp_feats = list(var1 = "var1", grp1 = c( + "var10_0", "var12_2", + "var2_b", "var3_yes", "var4_1", "var9_x" + ), grp2 = c( + "var10_1", + "var12_1", "var2_a" + ), grp3 = c( + "var12_3", "var2_c", "var7_1", + "var9_y" + ), var8 = "var8"), removed_feats = c( + "var5", "var6", + "var11" + )) + )) %>% suppressMessages() +}) + + +test_that("default parameter for impute_in_preprocessing is TRUE", { + expect_message( + expect_equal( + preprocess_data(test_df, "outcome", + prefilter_threshold = -1 + ), + preprocess_data(test_df, "outcome", + prefilter_threshold = -1, impute_in_preprocessing = TRUE)) + , + "Removed " + ) %>% suppressMessages() +}) + + diff --git a/tests/testthat/test-run_ml.R b/tests/testthat/test-run_ml.R index e91a0611..ee140858 100644 --- a/tests/testthat/test-run_ml.R +++ b/tests/testthat/test-run_ml.R @@ -309,3 +309,86 @@ test_that("models use case weights when provided", { expect_true("weights" %in% 
colnames(results_custom_train$trained_model$pred)) expect_false("weights" %in% colnames(otu_mini_bin_results_glmnet$trained_model$pred)) }) + +test_that("make sure impute function on train data set works", { + train_data <- data.frame(outcome = c("normal", "normal", "cancer", "cancer"), + var1 = 1:4, + var2 = c("a", "b", "c", "d"), + var3 = c("no", "yes", "no", "no"), + var4 = c(0, 1, 0, 0), + var5 = c(0, 0, 0, 0), + var6 = c("no", "no", "no", "no"), + var7 = c(1, 1, 0, 0), + var8 = c(5, 6, NA, 7), + var9 = c(NA, 1, 1, 0), + var10 = c(1, 0, NA, NA)) + test_data <- data.frame(outcome = c("normal", "normal", "cancer", "cancer"), var11 = c(1, 1, NA, NA), + var12 = c(1, 2, NA, 4)) + train_data_output <- dplyr::tibble( + outcome = c("normal", "normal", "cancer", "cancer"), + var1 = c('1', '2', '3', '4'), + var2 = c("a", "b", "c", "d"), + var3 = c("no", "yes", "no", "no"), + var4 = c('0', '1', '0', '0'), + var5 = c('0', '0', '0', '0'), + var6 = c("no", "no", "no", "no"), + var7 = c('1', '1', '0', '0'), + var8 = c('5', '6', '6', '7'), + var9 = c('1', '1', '1', '0'), + var10 = c('1', '0', '0.5', '0.5')) + results_output <- impute(train_data) + expect_equal(train_data_output, results_output) +}) + +test_that("make sure impute function on test data set works", { + train_data <- data.frame(outcome = c("normal", "normal", "cancer", "cancer"), + var1 = 1:4, + var2 = c("a", "b", "c", "d"), + var3 = c("no", "yes", "no", "no"), + var4 = c(0, 1, 0, 0), + var5 = c(0, 0, 0, 0), + var6 = c("no", "no", "no", "no"), + var7 = c(1, 1, 0, 0), + var8 = c(5, 6, NA, 7), + var9 = c(NA, 1, 1, 0), + var10 = c(1, 0, NA, NA)) + test_data <- data.frame(outcome = c("normal", "normal", "cancer", "cancer"), var11 = c(1, 1, NA, NA), + var12 = c(1, 2, NA, 4)) + test_data_output <- dplyr::tibble( + outcome = c("normal", "normal", "cancer", "cancer"), var11 = c('1', '1', '1', '1'), + var12 = c('1', '2', '2', '4')) + results_output <- impute(test_data) + expect_equal(test_data_output, results_output) + }) 
+ + +temp_df <- otu_mini_bin +Otu00011 <- c(6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, + 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, NA, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, + 5, 5) +mini_bin_with_nas <- otu_mini_bin %>% mutate(Otu00011) + +test_that("data gets imputed when impute_after_split is set to TRUE", { + results <- run_ml(mini_bin_with_nas, "glmnet", 'dx', NULL, FALSE, TRUE, 5, 100, NULL, 0.5, NULL, NULL, NULL, NULL, 1, 2019, TRUE) + temp <- colSums(is.na(results$test_data)) + num_nas <- sum(temp) + expect_equal(0, num_nas) %>% suppressMessages() + temp <- colSums(is.na(results$trained_model$trainingData)) + num_nas <- sum(temp) + expect_equal(0, num_nas) %>% suppressMessages() +}) + +test_that("data is not imputed when impute_after_split is set to FALSE", { + expect_error(run_ml(mini_bin_with_nas, "glmnet", 'dx', NULL, FALSE, TRUE, 5, 100, NULL, 0.5, NULL, NULL, NULL, NULL, 1, 2019, FALSE),NULL) %>% suppressMessages() +}) + +test_that("data is not imputed when impute_after_split is not set", { + expect_error(run_ml(mini_bin_with_nas, "glmnet", 'dx', NULL, FALSE, TRUE, 5, 100, NULL, 0.5, NULL, NULL, NULL, NULL, 1, 2019),NULL) %>% suppressMessages() +}) \ No newline at end of file diff --git a/vignettes/preprocess.Rmd b/vignettes/preprocess.Rmd index fe8ab009..31e7ab1c 100644 --- a/vignettes/preprocess.Rmd +++ b/vignettes/preprocess.Rmd @@ -369,6 +369,21 @@ going on (i.e. 
the median value is used): # preprocess raw dataset with missing value in continuous feature preprocess_data(dataset = miss_cont_df, outcome_colname = "outcome", method = NULL) ``` +#### Impute after the train/test split in `run_ml()` +To delay imputation until after the train/test split performed by `run_ml()`, set the `impute_in_preprocessing` option of `preprocess_data()` to `FALSE` as shown here: + +```{r} +# preprocess raw dataset with missing value in continuous feature +preprocess_data(dataset = miss_cont_df, outcome_colname = "outcome", method = NULL, impute_in_preprocessing=FALSE) +``` +Then, to impute the missing values after the train/test split, set the `impute_after_split` option of `run_ml()` to `TRUE` as shown here: +```{r, eval = FALSE} +results <- run_ml(otu_mini_bin, + "glmnet", + outcome_colname = "dx", + seed = 2019, impute_after_split = TRUE +) +``` ## Putting it all together