Skip to content

Commit

Permalink
Merge pull request #295 from SchlossLab/iss-268_compare_models
Browse files Browse the repository at this point in the history
Create function to compare models with a permutation test
  • Loading branch information
kelly-sovacool authored May 18, 2022
2 parents d6b511d + 2296a0d commit c418c6a
Show file tree
Hide file tree
Showing 80 changed files with 485 additions and 209 deletions.
1 change: 1 addition & 0 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
fail-fast: false
matrix:
config:
- {os: macOS-latest, r: 'devel'}
- {os: macOS-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# development version

- mikropml now requires R version 4.1.0 or greater due to an update in the randomForest package (#292).
- Fix bug where `cv_times` had no effect on reported repeats for cross-validation (#291, @kelly-sovacool).
- New function `compare_models()` compares the performance of two models with a permutation test (#295, @courtneyarmour).
- Fixed a bug where `cv_times` did not affect the reported repeats for cross-validation (#291, @kelly-sovacool).
- Made minor documentation improvements (#293, @kelly-sovacool)

# mikropml 1.2.2
Expand Down
2 changes: 1 addition & 1 deletion R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ check_outcome_value <- function(dataset, outcome_colname) {
stop(
paste0(
"A binary or multi-class outcome variable is required, but this dataset has ",
num_outcomes, " outcome(s): ", paste(outcomes, collapse = ", ")
num_outcomes, " outcome(s): ", paste(outcomes, collapse = ", ")
)
)
}
Expand Down
180 changes: 180 additions & 0 deletions R/compare_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#' Average metric difference
#'
#' Calculate the difference in the mean of the metric for two groups
#'
#' @param sub_data subset of the merged performance data frame for two groups
#' @param group_name name of column with group variable
#' @param metric metric to compare
#'
#' @return numeric difference in the average metric between the two groups
#'
#' @export
#' @author Courtney Armour, \email{armourc@@umich.edu}
#'
#' @examples
#' df <- dplyr::tibble(
#' condition = c("a", "a", "b", "b"),
#' AUC = c(.2, 0.3, 0.8, 0.9)
#' )
#' get_difference(df, "condition", "AUC")
#'
get_difference <- function(sub_data, group_name, metric) {
if (!is.numeric(sub_data %>% dplyr::pull(metric))) {
stop(paste0(
"The metric `", metric,
"` is not numeric, please check that you specified the right column."
))
}
means <- sub_data %>%
dplyr::group_by(.data[[group_name]]) %>%
dplyr::summarise(meanVal = mean(.data[[metric]]), .groups = "drop") %>%
dplyr::pull(meanVal)
abs(diff(means))
}

#' Shuffle the rows in a column
#'
#' @param dat a data frame containing `col_name`
#' @param col_name column name to shuffle
#'
#' @return `dat` with the rows of `col_name` shuffled
#' @export
#' @author Courtney R Armour, \email{armourc@@umich.edu}
#'
#' @examples
#' set.seed(123)
#' df <- dplyr::tibble(
#' condition = c("a", "a", "b", "b"),
#' AUC = c(.2, 0.3, 0.8, 0.9)
#' )
#' shuffle_group(df, "condition")
shuffle_group <- function(dat, col_name) {
if (!(col_name %in% colnames(dat))) {
stop(paste0("The col_name `", col_name, "` does not exist in the data frame."))
}
group_vals <- dat %>%
dplyr::pull({{ col_name }})
group_vals_shuffled <- base::sample(group_vals)

data_shuffled <- dat %>%
dplyr::mutate(!!col_name := group_vals_shuffled)

return(data_shuffled)
}


#' Calculated a permuted p-value comparing two models
#'
#' @inheritParams compare_models
#' @param group_1 name of one group to compare
#' @param group_2 name of other group to compare
#'
#' @return numeric p-value comparing two models
#' @export
#' @author Begüm Topçuoğlu, \email{topcuoglu.begum@@gmail.com}
#' @author Courtney R Armour, \email{armourc@@umich.edu}
#'
#' @examples
#' df <- dplyr::tibble(
#' model = c("rf", "rf", "glmnet", "glmnet", "svmRadial", "svmRadial"),
#' AUC = c(.2, 0.3, 0.8, 0.9, 0.85, 0.95)
#' )
#' set.seed(123)
#' permute_p_value(df, "AUC", "model", "rf", "glmnet", nperm = 100)
permute_p_value <- function(merged_data, metric, group_name, group_1, group_2, nperm = 10000) {
# check that the metric and group exist in data
if (!(metric %in% colnames(merged_data))) {
stop(paste0("The metric `", metric, "` does not exist in the data."))
}
if (!(group_name %in% colnames(merged_data))) {
stop(paste0("The group_name `", group_name, "` does not exist in the data."))
}
# check that group_1 and group_2 exist in the data
if (!(group_1 %in% (merged_data %>% dplyr::pull(group_name)))) {
stop(paste0("group_1 `", group_1, "` does not exist in the data."))
}
if (!(group_2 %in% (merged_data %>% dplyr::pull(group_name)))) {
stop(paste0("group_2 `", group_2, "` does not exist in the data."))
}

# subset results to select metric and group columns and
# filter to only the two groups of interest
sub_data <- merged_data %>%
dplyr::select({{ metric }}, {{ group_name }}) %>%
dplyr::filter(.data[[group_name]] == {{ group_1 }} | .data[[group_name]] == {{ group_2 }})

# observed difference: quantify the absolute value of the difference
# in metric between the two groups
metric_obs <- get_difference(sub_data, {{ group_name }}, {{ metric }})

# shuffled difference: quantify the absolute value of the difference
# in metric between the two groups after shuffling group labels
rep_fn <- select_apply("replicate")
metric_null <- rep_fn(
nperm,
get_difference(
shuffle_group(sub_data, group_name),
group_name,
metric
)
)

p_value <- calc_pvalue(metric_null, metric_obs)
return(p_value)
}


#' Compute all pairs of comparisons
#' calculate permuted p-value across all pairs of group variable.
#' wrapper for `permute_p_value`
#'
#' @param merged_data the concatenated performance data from `run_ml`
#' @param metric metric to compare, must be numeric
#' @param group_name column with group variables to compare
#' @param nperm number of permutations, default=10000
#'
#' @return a table of p-values for all pairs of group varible
#' @export
#' @author Courtney R Armour, \email{armourc@@umich.edu}
#'
#' @examples
#' df <- dplyr::tibble(
#' model = c("rf", "rf", "glmnet", "glmnet", "svmRadial", "svmRadial"),
#' AUC = c(.2, 0.3, 0.8, 0.9, 0.85, 0.95)
#' )
#' set.seed(123)
#' compare_models(df, "AUC", "model", nperm = 10)
compare_models <- function(merged_data, metric, group_name, nperm = 10000) {
# check that the metric and group exist in data
if (!(metric %in% colnames(merged_data))) {
stop("The metric does not exist in the data.")
}
if (!(group_name %in% colnames(merged_data))) {
stop("The group_name does not exist in the data.")
}

# identify all unique groups in group variable
groups <- merged_data %>%
dplyr::pull({{ group_name }}) %>%
unique()

# create a table with all possible comparisons of groups
# without repeating pairings
p_table <- tidyr::expand_grid(
x = 1:length(groups),
y = 1:length(groups)
) %>%
dplyr::filter(x < y) %>%
dplyr::mutate(
group1 = groups[x],
group2 = groups[y]
) %>%
dplyr::select(-x, -y) %>%
dplyr::group_by(group1, group2) %>%
dplyr::summarize(
p_value = permute_p_value(merged_data, metric, group_name, group1, group2, nperm),
.groups = "drop"
)

return(as.data.frame(p_table))
}
6 changes: 5 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ is_whole_number <- function(x, tol = .Machine$double.eps^0.5) {

#' Calculate the p-value for a permutation test
#'
#' compute Monte Carlo p-value with correction
#' based on formula from Page 158 of 'Bootstrap methods and their application'
#' By Davison & Hinkley 1997
#'
#' @param vctr vector of statistics
#' @param test_stat the test statistic
#'
Expand All @@ -243,5 +247,5 @@ is_whole_number <- function(x, tol = .Machine$double.eps^0.5) {
#' @noRd
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
calc_pvalue <- function(vctr, test_stat) {
return(sum(vctr > test_stat) / length(vctr))
return((sum(vctr >= test_stat) + 1) / (length(vctr) + 1))
}
5 changes: 3 additions & 2 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ reference:
- mikropml
- preprocess_data
- run_ml
- title: Plotting helpers
- title: Plotting & evalutation helpers
desc: >
Visualize performance to help you tune hyperparameters and choose model methods.
Visualize & evalutate performance to help you tune hyperparameters and choose model methods.
contents:
- compare_models
- starts_with('plot')
- tidy_perf_data
- get_hp_performance
Expand Down
Binary file modified data/otu_mini_bin.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rf.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rpart2.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_svmRadial.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_xgbTree.rda
Binary file not shown.
Binary file modified data/otu_mini_cont_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_cont_results_nocv.rda
Binary file not shown.
Binary file modified data/otu_mini_cv.rda
Binary file not shown.
Binary file modified data/otu_mini_multi.rda
Binary file not shown.
Binary file modified data/otu_mini_multi_group.rda
Binary file not shown.
Binary file modified data/otu_mini_multi_results_glmnet.rda
Binary file not shown.
4 changes: 2 additions & 2 deletions docs/404.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions docs/CODE_OF_CONDUCT.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions docs/CONTRIBUTING.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions docs/LICENSE-text.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions docs/LICENSE.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions docs/SUPPORT.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions docs/articles/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c418c6a

Please sign in to comment.