From ac6cdd44d94d678526c09748287e847199643bd3 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 27 Aug 2024 09:15:02 -0500 Subject: [PATCH] arbitrary test statistics in `calculate()` --- NAMESPACE | 1 + NEWS.md | 5 ++ R/calculate.R | 106 ++++++++++++++++++++++++++--- R/observe.R | 2 + man/calculate.Rd | 39 ++++++++++- man/observe.Rd | 39 ++++++++++- tests/testthat/_snaps/calculate.md | 46 +++++++++++++ tests/testthat/test-calculate.R | 87 ++++++++++++++++++++++- tests/testthat/test-observe.R | 13 ++++ 9 files changed, 322 insertions(+), 16 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 8431aee1..d948fca5 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +S3method(calc_impl,"function") S3method(calc_impl,Chisq) S3method(calc_impl,F) S3method(calc_impl,correlation) diff --git a/NEWS.md b/NEWS.md index 40e35b56..18f591a2 100755 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # infer (development version) +* Introduced support for arbitrary test statistics in `calculate()`. In addition + to the pre-implemented `calculate(stat)` options, taken as strings, users can + now supply a function defining any scalar-valued test statistic. See + `?calculate()` to learn more. + # infer 1.0.7 * The aliases `p_value()` and `conf_int()`, first deprecated 6 years ago, now diff --git a/R/calculate.R b/R/calculate.R index 9bcff31a..2b9aa8af 100755 --- a/R/calculate.R +++ b/R/calculate.R @@ -12,13 +12,15 @@ #' #' @param x The output from [generate()] for computation-based inference or the #' output from [hypothesize()] piped in to here for theory-based inference. -#' @param stat A string giving the type of the statistic to calculate. Current +#' @param stat A string giving the type of the statistic to calculate or a +#' function that takes in a replicate of `x` and returns a scalar value. Current #' options include `"mean"`, `"median"`, `"sum"`, `"sd"`, `"prop"`, `"count"`, #' `"diff in means"`, `"diff in medians"`, `"diff in props"`, `"Chisq"` (or #' `"chisq"`), `"F"` (or `"f"`), `"t"`, `"z"`, `"ratio of props"`, `"slope"`, #' `"odds ratio"`, `"ratio of means"`, and `"correlation"`. `infer` only #' supports theoretical tests on one or two means via the `"t"` distribution -#' and one or two proportions via the `"z"`. +#' and one or two proportions via the `"z"`. See the "Arbitrary test statistics" +#' section below for more on how to define a custom statistic. #' @param order A string vector of specifying the order in which the levels of #' the explanatory variable should be ordered for subtraction (or division #' for ratio-based statistics), where `order = c("first", "second")` means @@ -31,6 +33,38 @@ #' #' @return A tibble containing a `stat` column of calculated statistics. #' +#' @section Arbitrary test statistics: +#' +#' In addition to the pre-implemented statistics documented in `stat`, users can +#' supply an arbitrary test statistic by supplying a function to the `stat` +#' argument. +#' +#' The function should have arguments `stat(x, order, ...)`, where `x` is one +#' replicate's worth of `x`. The `order` argument and ellipses will be supplied +#' directly to the `stat` function. Internally, `calculate()` will split `x` up +#' into data frames by replicate and pass them one-by-one to the supplied `stat`. +#' For example, to implement `stat = "mean"` as a function, one could write: +#' +#' ```r +#' stat_mean <- function(x, order, ...) {mean(x$hours)} +#' obs_mean <- +#' gss %>% +#' specify(response = hours) %>% +#' calculate(stat = stat_mean) +#' +#' set.seed(1) +#' null_dist_mean <- +#' gss %>% +#' specify(response = hours) %>% +#' hypothesize(null = "point", mu = 40) %>% +#' generate(reps = 5, type = "bootstrap") %>% +#' calculate(stat = stat_mean) +#' ``` +#' +#' Note that the same `stat_mean` function is supplied to both `generate()`d and +#' non-`generate()`d infer objects--no need to implement support for grouping +#' by `replicate` yourself. +#' #' @section Missing levels in small samples: #' In some cases, when bootstrapping with small samples, some generated #' bootstrap samples will have only one level of the explanatory variable @@ -97,22 +131,23 @@ calculate <- function(x, ...) { check_type(x, tibble::is_tibble) check_if_mlr(x, "calculate") - stat <- check_calculate_stat(stat) - check_input_vs_stat(x, stat) - check_point_params(x, stat) + stat_chr <- stat_chr(stat) + stat_chr <- check_calculate_stat(stat_chr) + check_input_vs_stat(x, stat_chr) + check_point_params(x, stat_chr) - order <- check_order(x, order, in_calculate = TRUE, stat) + order <- check_order(x, order, in_calculate = TRUE, stat_chr) if (!is_generated(x)) { x$replicate <- 1L } - x <- message_on_excessive_null(x, stat = stat, fn = "calculate") - x <- warn_on_insufficient_null(x, stat, ...) + x <- message_on_excessive_null(x, stat = stat_chr, fn = "calculate") + x <- warn_on_insufficient_null(x, stat_chr, ...) # Use S3 method to match correct calculation result <- calc_impl( - structure(stat, class = gsub(" ", "_", stat)), x, order, ... + structure(stat, class = gsub(" ", "_", stat_chr)), x, order, ... ) result <- copy_attrs(to = result, from = x) @@ -144,9 +179,19 @@ check_if_mlr <- function(x, fn, call = caller_env()) { } } -check_calculate_stat <- function(stat, call = caller_env()) { +stat_chr <- function(stat) { + if (rlang::is_function(stat)) { + return("function") + } + stat +} + +check_calculate_stat <- function(stat, call = caller_env()) { check_type(stat, rlang::is_string, call = call) + if (identical(stat, "function")) { + return(stat) + } # Check for possible `stat` aliases alias_match_id <- match(stat, implemented_stats_aliases[["alias"]]) @@ -178,6 +223,10 @@ check_input_vs_stat <- function(x, stat, call = caller_env()) { ) } + if (identical(stat, "function")) { + return(x) + } + if (!stat %in% possible_stats) { if (has_explanatory(x)) { msg_tail <- glue( @@ -252,7 +301,7 @@ message_on_excessive_null <- function(x, stat = "mean", fn) { warn_on_insufficient_null <- function(x, stat, ...) { if (!is_hypothesized(x) && !has_explanatory(x) && - !stat %in% untheorized_stats && + !stat %in% c(untheorized_stats, "function") && !(stat == "t" && "mu" %in% names(list(...)))) { attr(x, "null") <- "point" attr(x, "params") <- assume_null(x, stat) @@ -626,3 +675,38 @@ calc_impl.z <- function(type, x, order, ...) { df_out } } + +#' @export +calc_impl.function <- function(type, x, order, ..., call = rlang::caller_env()) { + rlang::try_fetch( + { + if (!identical(dplyr::group_vars(x), "replicate")) { + x <- dplyr::group_by(x, replicate) + } + x_by_replicate <- dplyr::group_split(x) + res <- purrr::map(x_by_replicate, ~type(.x, order, ...)) + }, + error = function(cnd) {rethrow_stat_cnd(cnd, call = call)}, + warning = function(cnd) {rethrow_stat_cnd(cnd, call = call)} + ) + + if (!rlang::is_scalar_atomic(res[[1]])) { + cli::cli_abort( + c( + "The supplied {.arg stat} function must return a scalar value.", + "i" = "It returned {.obj_type_friendly {res[[1]]}}." + ), + call = call + ) + } + + tibble::new_tibble(list(stat = unlist(res))) +} + +rethrow_stat_cnd <- function(cnd, call = call) { + cli::cli_abort( + "The supplied {.arg stat} function encountered an issue.", + parent = cnd, + call = call + ) +} diff --git a/R/observe.R b/R/observe.R index cc6e7ddc..d437e5c1 100644 --- a/R/observe.R +++ b/R/observe.R @@ -15,6 +15,8 @@ #' #' @return A 1-column tibble containing the calculated statistic `stat`. #' +#' @inheritSection calculate Arbitrary test statistics +#' #' @examples #' # calculating the observed mean number of hours worked per week #' gss %>% diff --git a/man/calculate.Rd b/man/calculate.Rd index f2739ca9..e2a595f6 100755 --- a/man/calculate.Rd +++ b/man/calculate.Rd @@ -17,13 +17,15 @@ calculate( \item{x}{The output from \code{\link[=generate]{generate()}} for computation-based inference or the output from \code{\link[=hypothesize]{hypothesize()}} piped in to here for theory-based inference.} -\item{stat}{A string giving the type of the statistic to calculate. Current +\item{stat}{A string giving the type of the statistic to calculate or a +function that takes in a replicate of \code{x} and returns a scalar value. Current options include \code{"mean"}, \code{"median"}, \code{"sum"}, \code{"sd"}, \code{"prop"}, \code{"count"}, \code{"diff in means"}, \code{"diff in medians"}, \code{"diff in props"}, \code{"Chisq"} (or \code{"chisq"}), \code{"F"} (or \code{"f"}), \code{"t"}, \code{"z"}, \code{"ratio of props"}, \code{"slope"}, \code{"odds ratio"}, \code{"ratio of means"}, and \code{"correlation"}. \code{infer} only supports theoretical tests on one or two means via the \code{"t"} distribution -and one or two proportions via the \code{"z"}.} +and one or two proportions via the \code{"z"}. See the "Arbitrary test statistics" +section below for more on how to define a custom statistic.} \item{order}{A string vector of specifying the order in which the levels of the explanatory variable should be ordered for subtraction (or division @@ -48,6 +50,39 @@ supplied \code{stat} for each \code{replicate}. Learn more in \code{vignette("infer")}. } +\section{Arbitrary test statistics}{ + + +In addition to the pre-implemented statistics documented in \code{stat}, users can +supply an arbitrary test statistic by supplying a function to the \code{stat} +argument. + +The function should have arguments \code{stat(x, order, ...)}, where \code{x} is one +replicate's worth of \code{x}. The \code{order} argument and ellipses will be supplied +directly to the \code{stat} function. Internally, \code{calculate()} will split \code{x} up +into data frames by replicate and pass them one-by-one to the supplied \code{stat}. +For example, to implement \code{stat = "mean"} as a function, one could write: + +\if{html}{\out{
}}\preformatted{stat_mean <- function(x, order, ...) \{mean(x$hours)\} +obs_mean <- + gss \%>\% + specify(response = hours) \%>\% + calculate(stat = stat_mean) + +set.seed(1) +null_dist_mean <- + gss \%>\% + specify(response = hours) \%>\% + hypothesize(null = "point", mu = 40) \%>\% + generate(reps = 5, type = "bootstrap") \%>\% + calculate(stat = stat_mean) +}\if{html}{\out{
}} + +Note that the same \code{stat_mean} function is supplied to both \code{generate()}d and +non-\code{generate()}d infer objects--no need to implement support for grouping +by \code{replicate} yourself. +} + \section{Missing levels in small samples}{ In some cases, when bootstrapping with small samples, some generated diff --git a/man/observe.Rd b/man/observe.Rd index 9945a85a..46c86f95 100644 --- a/man/observe.Rd +++ b/man/observe.Rd @@ -67,13 +67,15 @@ hypotheses when the specified response variable is continuous.} \item{sigma}{The true standard deviation (any numerical value). To be used with point null hypotheses.} -\item{stat}{A string giving the type of the statistic to calculate. Current +\item{stat}{A string giving the type of the statistic to calculate or a +function that takes in a replicate of \code{x} and returns a scalar value. Current options include \code{"mean"}, \code{"median"}, \code{"sum"}, \code{"sd"}, \code{"prop"}, \code{"count"}, \code{"diff in means"}, \code{"diff in medians"}, \code{"diff in props"}, \code{"Chisq"} (or \code{"chisq"}), \code{"F"} (or \code{"f"}), \code{"t"}, \code{"z"}, \code{"ratio of props"}, \code{"slope"}, \code{"odds ratio"}, \code{"ratio of means"}, and \code{"correlation"}. \code{infer} only supports theoretical tests on one or two means via the \code{"t"} distribution -and one or two proportions via the \code{"z"}.} +and one or two proportions via the \code{"z"}. See the "Arbitrary test statistics" +section below for more on how to define a custom statistic.} \item{order}{A string vector of specifying the order in which the levels of the explanatory variable should be ordered for subtraction (or division @@ -97,6 +99,39 @@ null hypothesis parameter is supplied. Learn more in \code{vignette("infer")}. } +\section{Arbitrary test statistics}{ + + +In addition to the pre-implemented statistics documented in \code{stat}, users can +supply an arbitrary test statistic by supplying a function to the \code{stat} +argument. + +The function should have arguments \code{stat(x, order, ...)}, where \code{x} is one +replicate's worth of \code{x}. The \code{order} argument and ellipses will be supplied +directly to the \code{stat} function. Internally, \code{calculate()} will split \code{x} up +into data frames by replicate and pass them one-by-one to the supplied \code{stat}. +For example, to implement \code{stat = "mean"} as a function, one could write: + +\if{html}{\out{
}}\preformatted{stat_mean <- function(x, order, ...) \{mean(x$hours)\} +obs_mean <- + gss \%>\% + specify(response = hours) \%>\% + calculate(stat = stat_mean) + +set.seed(1) +null_dist_mean <- + gss \%>\% + specify(response = hours) \%>\% + hypothesize(null = "point", mu = 40) \%>\% + generate(reps = 5, type = "bootstrap") \%>\% + calculate(stat = stat_mean) +}\if{html}{\out{
}} + +Note that the same \code{stat_mean} function is supplied to both \code{generate()}d and +non-\code{generate()}d infer objects--no need to implement support for grouping +by \code{replicate} yourself. +} + \examples{ # calculating the observed mean number of hours worked per week gss \%>\% diff --git a/tests/testthat/_snaps/calculate.md b/tests/testthat/_snaps/calculate.md index 48311d34..7d56b4ad 100644 --- a/tests/testthat/_snaps/calculate.md +++ b/tests/testthat/_snaps/calculate.md @@ -527,3 +527,49 @@ ! Multiple explanatory variables are not supported in `calculate()`. i When working with multiple explanatory variables, use `fit()` (`?infer::fit.infer()`) instead. +# arbitrary test statistic works + + Code + gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) { + mean(x$hour) + }) + Condition + Error in `calculate()`: + ! The supplied `stat` function encountered an issue. + Caused by warning: + ! Unknown or uninitialised column: `hour`. + +--- + + Code + gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) { + mean("hey there") + }) + Condition + Error in `calculate()`: + ! The supplied `stat` function encountered an issue. + Caused by warning in `mean.default()`: + ! argument is not numeric or logical: returning NA + +--- + + Code + gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) { + data.frame(woops = mean(x$hours)) + }) + Condition + Error in `calculate()`: + ! The supplied `stat` function must return a scalar value. + i It returned a data frame. + +--- + + Code + gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) { + identity + }) + Condition + Error in `calculate()`: + ! The supplied `stat` function must return a scalar value. + i It returned a function. + diff --git a/tests/testthat/test-calculate.R b/tests/testthat/test-calculate.R index 1e00c091..9b5c5fa1 100644 --- a/tests/testthat/test-calculate.R +++ b/tests/testthat/test-calculate.R @@ -821,13 +821,98 @@ test_that("reported standard errors are correct", { expect_null(attr(rat_hat, "se")) }) +test_that("arbitrary test statistic works", { + # observed test statistics match pre-implemented ones + obs_stat_manual <- + gss %>% + specify(response = hours) %>% + calculate(stat = function(x, ...) {mean(x$hours)}) + + obs_stat_pre_implemented <- + gss %>% + specify(response = hours) %>% + calculate(stat = function(x, ...) {mean(x$hours)}) + + expect_equal(obs_stat_manual, obs_stat_pre_implemented) + + # can supply a stat totally unknown to infer + mode_hours <- function(x, ...) { + hours_tbl <- table(x$hours) + as.numeric(names(sort(hours_tbl)))[length(hours_tbl)] + } + obs_stat_manual <- + gss %>% + specify(response = hours) %>% + calculate(stat = mode_hours) + expect_s3_class(obs_stat_manual, c("infer", "tbl_df")) + expect_named(obs_stat_manual, "stat") + expect_equal(obs_stat_manual$stat[[1]], 40) + # ...even one with a character value! + mode_partyid <- function(x, ...) { + partyid_tbl <- table(x$partyid) + names(sort(partyid_tbl))[length(partyid_tbl)] + } + obs_stat_manual <- + gss %>% + specify(response = partyid) %>% + calculate(stat = mode_partyid) + + expect_s3_class(obs_stat_manual, c("infer", "tbl_df")) + expect_named(obs_stat_manual, "stat") + expect_equal(obs_stat_manual$stat[[1]], "ind") + + # resampled test statistics match pre-implemented ones + set.seed(1) + stat_dist_manual <- + gss %>% + specify(response = hours) %>% + hypothesize(null = "point", mu = 40) %>% + generate(reps = 5, type = "bootstrap") %>% + calculate(stat = function(x, ...) {mean(x$hours)}) + + set.seed(1) + stat_dist_pre_implemented <- + gss %>% + specify(response = hours) %>% + hypothesize(null = "point", mu = 40) %>% + generate(reps = 5, type = "bootstrap") %>% + calculate(stat = function(x, ...) {mean(x$hours)}) + expect_equal(stat_dist_manual, stat_dist_pre_implemented) + # errors and warnings are rethrown informatively + expect_snapshot( + error = TRUE, + gss %>% + specify(response = hours) %>% + # intentionally misspell `hour` to trigger warning + calculate(stat = function(x, ...) {mean(x$hour)}) + ) + expect_snapshot( + error = TRUE, + gss %>% + specify(response = hours) %>% + # intentionally raise error + calculate(stat = function(x, ...) {mean("hey there")}) + ) + # incompatible functions are handled gracefully + expect_snapshot( + error = TRUE, + gss %>% + specify(response = hours) %>% + calculate(stat = function(x, ...) {data.frame(woops = mean(x$hours))}) + ) - + expect_snapshot( + error = TRUE, + gss %>% + specify(response = hours) %>% + calculate(stat = function(x, ...) {identity}) + ) +}) diff --git a/tests/testthat/test-observe.R b/tests/testthat/test-observe.R index d3a4bd6f..2ef50a8d 100644 --- a/tests/testthat/test-observe.R +++ b/tests/testthat/test-observe.R @@ -161,3 +161,16 @@ test_that("observe() output is the same as the old wrappers", { ) }) +test_that("observe() can handle arbitrary test statistics", { + mean_manual <- + gss %>% + specify(response = hours) %>% + calculate(stat = "mean") + + mean_observe <- + observe(gss, response = hours, stat = function(x, ...) {mean(x$hours)}) + + # use `ignore_attr` since infer will only calculate standard errors with + # some pre-implemented statistics + expect_equal(mean_manual, mean_observe, ignore_attr = TRUE) +})