Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arbitrary test statistics in calculate() #542

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(calc_impl,"function")
S3method(calc_impl,Chisq)
S3method(calc_impl,F)
S3method(calc_impl,correlation)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# infer (development version)

* Introduced support for arbitrary test statistics in `calculate()`. In addition
to the pre-implemented `calculate(stat)` options, taken as strings, users can
now supply a function defining any scalar-valued test statistic. See
`?calculate()` to learn more.

* Added missing commas and addressed formatting issues throughout the vignettes and articles. Backticks for package names were removed and missing parentheses for functions were added (@Joscelinrocha).


# infer 1.0.7

* The aliases `p_value()` and `conf_int()`, first deprecated 6 years ago, now
Expand Down
106 changes: 95 additions & 11 deletions R/calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
#'
#' @param x The output from [generate()] for computation-based inference or the
#' output from [hypothesize()] piped in to here for theory-based inference.
#' @param stat A string giving the type of the statistic to calculate. Current
#' @param stat A string giving the type of the statistic to calculate or a
#' function that takes in a replicate of `x` and returns a scalar value. Current
#' options include `"mean"`, `"median"`, `"sum"`, `"sd"`, `"prop"`, `"count"`,
#' `"diff in means"`, `"diff in medians"`, `"diff in props"`, `"Chisq"` (or
#' `"chisq"`), `"F"` (or `"f"`), `"t"`, `"z"`, `"ratio of props"`, `"slope"`,
#' `"odds ratio"`, `"ratio of means"`, and `"correlation"`. `infer` only
#' supports theoretical tests on one or two means via the `"t"` distribution
#' and one or two proportions via the `"z"`.
#' and one or two proportions via the `"z"`. See the "Arbitrary test statistics"
#' section below for more on how to define a custom statistic.
#' @param order A string vector of specifying the order in which the levels of
#' the explanatory variable should be ordered for subtraction (or division
#' for ratio-based statistics), where `order = c("first", "second")` means
Expand All @@ -31,6 +33,38 @@
#'
#' @return A tibble containing a `stat` column of calculated statistics.
#'
#' @section Arbitrary test statistics:
#'
#' In addition to the pre-implemented statistics documented in `stat`, users can
#' supply an arbitrary test statistic by supplying a function to the `stat`
#' argument.
#'
#' The function should have arguments `stat(x, order, ...)`, where `x` is one
#' replicate's worth of `x`. The `order` argument and ellipses will be supplied
#' directly to the `stat` function. Internally, `calculate()` will split `x` up
#' into data frames by replicate and pass them one-by-one to the supplied `stat`.
#' For example, to implement `stat = "mean"` as a function, one could write:
#'
#' ```r
#' stat_mean <- function(x, order, ...) {mean(x$hours)}
#' obs_mean <-
#' gss %>%
#' specify(response = hours) %>%
#' calculate(stat = stat_mean)
#'
#' set.seed(1)
#' null_dist_mean <-
#' gss %>%
#' specify(response = hours) %>%
#' hypothesize(null = "point", mu = 40) %>%
#' generate(reps = 5, type = "bootstrap") %>%
#' calculate(stat = stat_mean)
#' ```
#'
#' Note that the same `stat_mean` function is supplied to both `generate()`d and
#' non-`generate()`d infer objects--no need to implement support for grouping
#' by `replicate` yourself.
#'
#' @section Missing levels in small samples:
#' In some cases, when bootstrapping with small samples, some generated
#' bootstrap samples will have only one level of the explanatory variable
Expand Down Expand Up @@ -97,22 +131,23 @@ calculate <- function(x,
...) {
check_type(x, tibble::is_tibble)
check_if_mlr(x, "calculate")
stat <- check_calculate_stat(stat)
check_input_vs_stat(x, stat)
check_point_params(x, stat)
stat_chr <- stat_chr(stat)
stat_chr <- check_calculate_stat(stat_chr)
check_input_vs_stat(x, stat_chr)
check_point_params(x, stat_chr)

order <- check_order(x, order, in_calculate = TRUE, stat)
order <- check_order(x, order, in_calculate = TRUE, stat_chr)

if (!is_generated(x)) {
x$replicate <- 1L
}

x <- message_on_excessive_null(x, stat = stat, fn = "calculate")
x <- warn_on_insufficient_null(x, stat, ...)
x <- message_on_excessive_null(x, stat = stat_chr, fn = "calculate")
x <- warn_on_insufficient_null(x, stat_chr, ...)

# Use S3 method to match correct calculation
result <- calc_impl(
structure(stat, class = gsub(" ", "_", stat)), x, order, ...
structure(stat, class = gsub(" ", "_", stat_chr)), x, order, ...
)

result <- copy_attrs(to = result, from = x)
Expand Down Expand Up @@ -144,9 +179,19 @@ check_if_mlr <- function(x, fn, call = caller_env()) {
}
}

check_calculate_stat <- function(stat, call = caller_env()) {
stat_chr <- function(stat) {
if (rlang::is_function(stat)) {
return("function")
}

stat
}

check_calculate_stat <- function(stat, call = caller_env()) {
check_type(stat, rlang::is_string, call = call)
if (identical(stat, "function")) {
return(stat)
}

# Check for possible `stat` aliases
alias_match_id <- match(stat, implemented_stats_aliases[["alias"]])
Expand Down Expand Up @@ -178,6 +223,10 @@ check_input_vs_stat <- function(x, stat, call = caller_env()) {
)
}

if (identical(stat, "function")) {
return(x)
}

if (!stat %in% possible_stats) {
if (has_explanatory(x)) {
msg_tail <- glue(
Expand Down Expand Up @@ -252,7 +301,7 @@ message_on_excessive_null <- function(x, stat = "mean", fn) {
warn_on_insufficient_null <- function(x, stat, ...) {
if (!is_hypothesized(x) &&
!has_explanatory(x) &&
!stat %in% untheorized_stats &&
!stat %in% c(untheorized_stats, "function") &&
!(stat == "t" && "mu" %in% names(list(...)))) {
attr(x, "null") <- "point"
attr(x, "params") <- assume_null(x, stat)
Expand Down Expand Up @@ -626,3 +675,38 @@ calc_impl.z <- function(type, x, order, ...) {
df_out
}
}

#' @export
calc_impl.function <- function(type, x, order, ..., call = rlang::caller_env()) {
rlang::try_fetch(
{
if (!identical(dplyr::group_vars(x), "replicate")) {
x <- dplyr::group_by(x, replicate)
}
x_by_replicate <- dplyr::group_split(x)
res <- purrr::map(x_by_replicate, ~type(.x, order, ...))
},
error = function(cnd) {rethrow_stat_cnd(cnd, call = call)},
warning = function(cnd) {rethrow_stat_cnd(cnd, call = call)}
)

if (!rlang::is_scalar_atomic(res[[1]])) {
cli::cli_abort(
c(
"The supplied {.arg stat} function must return a scalar value.",
"i" = "It returned {.obj_type_friendly {res[[1]]}}."
),
call = call
)
}

tibble::new_tibble(list(stat = unlist(res)))
}

rethrow_stat_cnd <- function(cnd, call = call) {
cli::cli_abort(
"The supplied {.arg stat} function encountered an issue.",
parent = cnd,
call = call
)
}
2 changes: 2 additions & 0 deletions R/observe.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#'
#' @return A 1-column tibble containing the calculated statistic `stat`.
#'
#' @inheritSection calculate Arbitrary test statistics
#'
#' @examples
#' # calculating the observed mean number of hours worked per week
#' gss %>%
Expand Down
39 changes: 37 additions & 2 deletions man/calculate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 37 additions & 2 deletions man/observe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions tests/testthat/_snaps/calculate.md
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,49 @@
! Multiple explanatory variables are not supported in `calculate()`.
i When working with multiple explanatory variables, use `fit()` (`?infer::fit.infer()`) instead.

# arbitrary test statistic works

Code
gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) {
mean(x$hour)
})
Condition
Error in `calculate()`:
! The supplied `stat` function encountered an issue.
Caused by warning:
! Unknown or uninitialised column: `hour`.

---

Code
gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) {
mean("hey there")
})
Condition
Error in `calculate()`:
! The supplied `stat` function encountered an issue.
Caused by warning in `mean.default()`:
! argument is not numeric or logical: returning NA

---

Code
gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) {
data.frame(woops = mean(x$hours))
})
Condition
Error in `calculate()`:
! The supplied `stat` function must return a scalar value.
i It returned a data frame.

---

Code
gss %>% specify(response = hours) %>% calculate(stat = function(x, ...) {
identity
})
Condition
Error in `calculate()`:
! The supplied `stat` function must return a scalar value.
i It returned a function.

Loading
Loading