diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 26f60654..e12e34f9 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -19,6 +19,7 @@ jobs: fail-fast: false matrix: config: + - {os: macOS-latest, r: 'devel'} - {os: macOS-latest, r: 'release'} - {os: windows-latest, r: 'release'} - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} diff --git a/NEWS.md b/NEWS.md index 452c4b8e..af236574 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,8 @@ # development version - mikropml now requires R version 4.1.0 or greater due to an update in the randomForest package (#292). -- Fix bug where `cv_times` had no effect on reported repeats for cross-validation (#291, @kelly-sovacool). +- New function `compare_models()` compares the performance of two models with a permutation test (#295, @courtneyarmour). +- Fixed a bug where `cv_times` did not affect the reported repeats for cross-validation (#291, @kelly-sovacool). - Made minor documentation improvements (#293, @kelly-sovacool) # mikropml 1.2.2 diff --git a/R/checks.R b/R/checks.R index 0ce412af..08a2a67f 100644 --- a/R/checks.R +++ b/R/checks.R @@ -274,7 +274,7 @@ check_outcome_value <- function(dataset, outcome_colname) { stop( paste0( "A binary or multi-class outcome variable is required, but this dataset has ", - num_outcomes, " outcome(s): ", paste(outcomes, collapse = ", ") + num_outcomes, " outcome(s): ", paste(outcomes, collapse = ", ") ) ) } diff --git a/R/compare_models.R b/R/compare_models.R new file mode 100644 index 00000000..13ba1a19 --- /dev/null +++ b/R/compare_models.R @@ -0,0 +1,180 @@ +#' Average metric difference +#' +#' Calculate the difference in the mean of the metric for two groups +#' +#' @param sub_data subset of the merged performance data frame for two groups +#' @param group_name name of column with group variable +#' @param metric metric to compare +#' +#' @return numeric difference in the average metric between the two groups +#' +#' @export +#' @author Courtney Armour, \email{armourc@@umich.edu} +#' +#' @examples +#' df <- dplyr::tibble( +#' condition = c("a", "a", "b", "b"), +#' AUC = c(.2, 0.3, 0.8, 0.9) +#' ) +#' get_difference(df, "condition", "AUC") +#' +get_difference <- function(sub_data, group_name, metric) { + if (!is.numeric(sub_data %>% dplyr::pull(metric))) { + stop(paste0( + "The metric `", metric, + "` is not numeric, please check that you specified the right column." + )) + } + means <- sub_data %>% + dplyr::group_by(.data[[group_name]]) %>% + dplyr::summarise(meanVal = mean(.data[[metric]]), .groups = "drop") %>% + dplyr::pull(meanVal) + abs(diff(means)) +} + +#' Shuffle the rows in a column +#' +#' @param dat a data frame containing `col_name` +#' @param col_name column name to shuffle +#' +#' @return `dat` with the rows of `col_name` shuffled +#' @export +#' @author Courtney R Armour, \email{armourc@@umich.edu} +#' +#' @examples +#' set.seed(123) +#' df <- dplyr::tibble( +#' condition = c("a", "a", "b", "b"), +#' AUC = c(.2, 0.3, 0.8, 0.9) +#' ) +#' shuffle_group(df, "condition") +shuffle_group <- function(dat, col_name) { + if (!(col_name %in% colnames(dat))) { + stop(paste0("The col_name `", col_name, "` does not exist in the data frame.")) + } + group_vals <- dat %>% + dplyr::pull({{ col_name }}) + group_vals_shuffled <- base::sample(group_vals) + + data_shuffled <- dat %>% + dplyr::mutate(!!col_name := group_vals_shuffled) + + return(data_shuffled) +} + + +#' Calculated a permuted p-value comparing two models +#' +#' @inheritParams compare_models +#' @param group_1 name of one group to compare +#' @param group_2 name of other group to compare +#' +#' @return numeric p-value comparing two models +#' @export +#' @author Begüm Topçuoğlu, \email{topcuoglu.begum@@gmail.com} +#' @author Courtney R Armour, \email{armourc@@umich.edu} +#' +#' @examples +#' df <- dplyr::tibble( +#' model = c("rf", "rf", "glmnet", "glmnet", "svmRadial", "svmRadial"), +#' AUC = c(.2, 0.3, 0.8, 0.9, 0.85, 0.95) +#' ) +#' set.seed(123) +#' permute_p_value(df, "AUC", "model", "rf", "glmnet", nperm = 100) +permute_p_value <- function(merged_data, metric, group_name, group_1, group_2, nperm = 10000) { + # check that the metric and group exist in data + if (!(metric %in% colnames(merged_data))) { + stop(paste0("The metric `", metric, "` does not exist in the data.")) + } + if (!(group_name %in% colnames(merged_data))) { + stop(paste0("The group_name `", group_name, "` does not exist in the data.")) + } + # check that group_1 and group_2 exist in the data + if (!(group_1 %in% (merged_data %>% dplyr::pull(group_name)))) { + stop(paste0("group_1 `", group_1, "` does not exist in the data.")) + } + if (!(group_2 %in% (merged_data %>% dplyr::pull(group_name)))) { + stop(paste0("group_2 `", group_2, "` does not exist in the data.")) + } + + # subset results to select metric and group columns and + # filter to only the two groups of interest + sub_data <- merged_data %>% + dplyr::select({{ metric }}, {{ group_name }}) %>% + dplyr::filter(.data[[group_name]] == {{ group_1 }} | .data[[group_name]] == {{ group_2 }}) + + # observed difference: quantify the absolute value of the difference + # in metric between the two groups + metric_obs <- get_difference(sub_data, {{ group_name }}, {{ metric }}) + + # shuffled difference: quantify the absolute value of the difference + # in metric between the two groups after shuffling group labels + rep_fn <- select_apply("replicate") + metric_null <- rep_fn( + nperm, + get_difference( + shuffle_group(sub_data, group_name), + group_name, + metric + ) + ) + + p_value <- calc_pvalue(metric_null, metric_obs) + return(p_value) +} + + +#' Compute all pairs of comparisons +#' calculate permuted p-value across all pairs of group variable. +#' wrapper for `permute_p_value` +#' +#' @param merged_data the concatenated performance data from `run_ml` +#' @param metric metric to compare, must be numeric +#' @param group_name column with group variables to compare +#' @param nperm number of permutations, default=10000 +#' +#' @return a table of p-values for all pairs of group varible +#' @export +#' @author Courtney R Armour, \email{armourc@@umich.edu} +#' +#' @examples +#' df <- dplyr::tibble( +#' model = c("rf", "rf", "glmnet", "glmnet", "svmRadial", "svmRadial"), +#' AUC = c(.2, 0.3, 0.8, 0.9, 0.85, 0.95) +#' ) +#' set.seed(123) +#' compare_models(df, "AUC", "model", nperm = 10) +compare_models <- function(merged_data, metric, group_name, nperm = 10000) { + # check that the metric and group exist in data + if (!(metric %in% colnames(merged_data))) { + stop("The metric does not exist in the data.") + } + if (!(group_name %in% colnames(merged_data))) { + stop("The group_name does not exist in the data.") + } + + # identify all unique groups in group variable + groups <- merged_data %>% + dplyr::pull({{ group_name }}) %>% + unique() + + # create a table with all possible comparisons of groups + # without repeating pairings + p_table <- tidyr::expand_grid( + x = 1:length(groups), + y = 1:length(groups) + ) %>% + dplyr::filter(x < y) %>% + dplyr::mutate( + group1 = groups[x], + group2 = groups[y] + ) %>% + dplyr::select(-x, -y) %>% + dplyr::group_by(group1, group2) %>% + dplyr::summarize( + p_value = permute_p_value(merged_data, metric, group_name, group1, group2, nperm), + .groups = "drop" + ) + + return(as.data.frame(p_table)) +} diff --git a/R/utils.R b/R/utils.R index 1b6bcb98..74c98234 100644 --- a/R/utils.R +++ b/R/utils.R @@ -234,6 +234,10 @@ is_whole_number <- function(x, tol = .Machine$double.eps^0.5) { #' Calculate the p-value for a permutation test #' +#' compute Monte Carlo p-value with correction +#' based on formula from Page 158 of 'Bootstrap methods and their application' +#' By Davison & Hinkley 1997 +#' #' @param vctr vector of statistics #' @param test_stat the test statistic #' @@ -243,5 +247,5 @@ is_whole_number <- function(x, tol = .Machine$double.eps^0.5) { #' @noRd #' @author Kelly Sovacool \email{sovacool@@umich.edu} calc_pvalue <- function(vctr, test_stat) { - return(sum(vctr > test_stat) / length(vctr)) + return((sum(vctr >= test_stat) + 1) / (length(vctr) + 1)) } diff --git a/_pkgdown.yml b/_pkgdown.yml index 186d6a96..b126508f 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -29,10 +29,11 @@ reference: - mikropml - preprocess_data - run_ml -- title: Plotting helpers +- title: Plotting & evalutation helpers desc: > - Visualize performance to help you tune hyperparameters and choose model methods. + Visualize & evalutate performance to help you tune hyperparameters and choose model methods. contents: + - compare_models - starts_with('plot') - tidy_perf_data - get_hp_performance diff --git a/data/otu_mini_bin.rda b/data/otu_mini_bin.rda index c3ef3843..ccb5cf7a 100644 Binary files a/data/otu_mini_bin.rda and b/data/otu_mini_bin.rda differ diff --git a/data/otu_mini_bin_results_glmnet.rda b/data/otu_mini_bin_results_glmnet.rda index 7a0f0b6d..2c85c788 100644 Binary files a/data/otu_mini_bin_results_glmnet.rda and b/data/otu_mini_bin_results_glmnet.rda differ diff --git a/data/otu_mini_bin_results_rf.rda b/data/otu_mini_bin_results_rf.rda index 6ba8ead6..ac6e8a1e 100644 Binary files a/data/otu_mini_bin_results_rf.rda and b/data/otu_mini_bin_results_rf.rda differ diff --git a/data/otu_mini_bin_results_rpart2.rda b/data/otu_mini_bin_results_rpart2.rda index 753c04aa..0e652ae6 100644 Binary files a/data/otu_mini_bin_results_rpart2.rda and b/data/otu_mini_bin_results_rpart2.rda differ diff --git a/data/otu_mini_bin_results_svmRadial.rda b/data/otu_mini_bin_results_svmRadial.rda index 46b3d262..8a3d16bf 100644 Binary files a/data/otu_mini_bin_results_svmRadial.rda and b/data/otu_mini_bin_results_svmRadial.rda differ diff --git a/data/otu_mini_bin_results_xgbTree.rda b/data/otu_mini_bin_results_xgbTree.rda index 1d923fe7..8b21b0bb 100644 Binary files a/data/otu_mini_bin_results_xgbTree.rda and b/data/otu_mini_bin_results_xgbTree.rda differ diff --git a/data/otu_mini_cont_results_glmnet.rda b/data/otu_mini_cont_results_glmnet.rda index ab97fcc3..44eaa16e 100644 Binary files a/data/otu_mini_cont_results_glmnet.rda and b/data/otu_mini_cont_results_glmnet.rda differ diff --git a/data/otu_mini_cont_results_nocv.rda b/data/otu_mini_cont_results_nocv.rda index 01c9ffdd..b613687c 100644 Binary files a/data/otu_mini_cont_results_nocv.rda and b/data/otu_mini_cont_results_nocv.rda differ diff --git a/data/otu_mini_cv.rda b/data/otu_mini_cv.rda index 9cdbd31b..b8e54749 100644 Binary files a/data/otu_mini_cv.rda and b/data/otu_mini_cv.rda differ diff --git a/data/otu_mini_multi.rda b/data/otu_mini_multi.rda index fb6f1af1..2c773fac 100644 Binary files a/data/otu_mini_multi.rda and b/data/otu_mini_multi.rda differ diff --git a/data/otu_mini_multi_group.rda b/data/otu_mini_multi_group.rda index 4b0a1521..07f7eff5 100644 Binary files a/data/otu_mini_multi_group.rda and b/data/otu_mini_multi_group.rda differ diff --git a/data/otu_mini_multi_results_glmnet.rda b/data/otu_mini_multi_results_glmnet.rda index 79d8dca5..9d506a8b 100644 Binary files a/data/otu_mini_multi_results_glmnet.rda and b/data/otu_mini_multi_results_glmnet.rda differ diff --git a/docs/404.html b/docs/404.html index 0f7568a8..986332d0 100644 --- a/docs/404.html +++ b/docs/404.html @@ -49,7 +49,7 @@ Reference