Skip to content

Commit

Permalink
feat: More complete across() (#306)
Browse files Browse the repository at this point in the history
* Weird behavior

* Avoid the need to capture expressions

* Memoise

* Unleash

* Auto suffix

* Support lists of length > 1
  • Loading branch information
krlmlr authored Nov 2, 2024
1 parent 073a94b commit 8e48b7c
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 44 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Imports:
glue,
jsonlite,
lifecycle,
memoise,
rlang (>= 1.0.6),
tibble,
tidyselect,
Expand Down
98 changes: 69 additions & 29 deletions R/duckplyr-across.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,10 @@ duckplyr_expand_across <- function(data, quo) {
fns <- as_quosure(expr$.fns, env)
fns <- quo_eval_fns(fns, mask = env, error_call = error_call)

# duckplyr doesn't currently support >1 function so we bail if we
# see that potential case, but to potentially allow for this in the future we
# manually wrap in a list using the default name of `"1"`.
if (!is.function(fns)) {
return(NULL)
}
fns <- list("1" = fns)
fn_exprs <- list(expr$.fns)

# In dplyr this evaluates in the mask to reproduce the `mutate()` or
# `summarise()` context. We don't have a mask here but it's probably fine in
# almost all cases.
names <- eval_tidy(expr$.names, env = env)
names <- names %||% "{.col}"

setup <- duckplyr_across_setup(
data,
Expand Down Expand Up @@ -105,16 +95,8 @@ duckplyr_expand_across <- function(data, quo) {
var <- vars[[i]]

for (j in seq_fns) {
fn_expr <- fn_exprs[[j]]

if (is_symbol(fn_expr)) {
# When we see a bare symbol like `across(x:y, mean)`, we don't
# want to inline the function itself, we want to inline its expression.
fn_call <- new_quosure(call2(fn_expr, sym(var)), env = env)
} else {
# Note: `mask` isn't actually used inside this helper
fn_call <- as_across_fn_call(fns[[j]], var, env, mask = env)
}
# Note: `mask` isn't actually used inside this helper
fn_call <- as_across_fn_call(fn_to_expr(fns[[j]], env), var, env, mask = env)

name <- names[[k]]

Expand All @@ -139,11 +121,6 @@ duckplyr_across_setup <- function(data,
names,
.caller_env,
error_call = caller_env()) {
stopifnot(
is.list(fns),
length(fns) == 1
)

data <- set_names(seq_along(data), data)

vars <- tidyselect::eval_select(
Expand All @@ -157,10 +134,32 @@ duckplyr_across_setup <- function(data,

names_fns <- names(fns)

glue_mask <- across_glue_mask(
.col = names_vars,
.fn = names_fns,
.caller_env = .caller_env
# apply `.names` smart default
if (is.function(fns)) {
names <- names %||% "{.col}"
fns <- list("1" = fns)
} else {
names <- names %||% "{.col}_{.fn}"
}

if (!is.list(fns)) {
abort("Expected a list.", .internal = TRUE)
}

# make sure fns has names, use number to replace unnamed
if (is.null(names(fns))) {
names_fns <- seq_along(fns)
} else {
names_fns <- names(fns)
empties <- which(names_fns == "")
if (length(empties)) {
names_fns[empties] <- empties
}
}

glue_mask <- across_glue_mask(.caller_env,
.col = rep(names_vars, each = length(fns)),
.fn = rep(names_fns , length(vars))
)
names <- vec_as_names(
glue(names, .envir = glue_mask),
Expand All @@ -175,6 +174,47 @@ duckplyr_across_setup <- function(data,
)
}

# Try to replace a function object by a `pkg::name` call that evaluates to it.
#
# `fn` is returned unchanged unless all of the following hold:
# - its enclosing environment is a namespace,
# - its hash is found in the export lookup of that namespace,
# - evaluating the resulting `pkg::name` call in `env` yields exactly `fn`.
fn_to_expr <- function(fn, env) {
  ns <- environment(fn)

  if (is_namespace(ns)) {
    # Environment that maps function hashes to export names
    lookup <- get_ns_exports_lookup(ns)
    export_name <- lookup[[hash(fn)]]

    if (!is.null(export_name)) {
      qualified <- call2("::", sym(getNamespaceName(ns)), sym(export_name))

      # Final safeguard: only use the call if it really resolves to `fn`
      if (identical(eval(qualified, env), fn)) {
        return(qualified)
      }
    }
  }

  fn
}

# Memoize get_ns_exports_lookup() to avoid recomputing the hash of
# every function in every namespace every time.
#
# The memoised version is assigned explicitly into the environment in which
# get_ns_exports_lookup() was defined. `<<-` starts its search in the *parent*
# of the environment where it is evaluated; assuming `on_load()` is rlang's
# (which evaluates the expression in the namespace itself), `<<-` would skip
# the namespace binding, fall through to the global environment and leave the
# package's own function unmemoised.
on_load({
  assign(
    "get_ns_exports_lookup",
    memoise::memoise(get_ns_exports_lookup),
    envir = environment(get_ns_exports_lookup)
  )
})

# Build a reverse lookup for the functions exported by a namespace.
#
# `ns` is a namespace environment. The result is an environment whose keys
# are the `hash()` values of the exported functions and whose values are the
# corresponding export names.
get_ns_exports_lookup <- function(ns) {
  names <- getNamespaceExports(ns)

  # Re-exported objects are bound in the imports environment (the parent of
  # the namespace), not in the namespace itself. A lookup restricted to `ns`
  # aborts with "value for ... not found" for such packages, so inherit from
  # the parents and map anything still unresolved to `NULL`, which the
  # function filter below drops.
  objs <- mget(names, ns, inherits = TRUE, ifnotfound = list(NULL))
  funs <- objs[map_lgl(objs, is.function)]

  hashes <- map_chr(funs, hash)
  # Reverse, return as environment
  new_environment(set_names(as.list(names(hashes)), hashes))
}

test_duckplyr_expand_across <- function(data, expr) {
quo <- new_dplyr_quosure(enquo(expr), is_named = FALSE, index = 1L)
out <- duckplyr_expand_across(data, quo)
Expand Down
5 changes: 5 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@
packageStartupMessage(msg)
}
}

# Placeholder that references `memoise::memoise()` from inside a function
# body so that R CMD check does not warn (presumably about the memoise
# import looking unused -- confirm against the check output).
dummy <- function() memoise::memoise()
35 changes: 32 additions & 3 deletions tests/testthat/_snaps/duckplyr-across.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Code
test_duckplyr_expand_across(c("x", "y"), across(x:y, mean))
Output
tibble(x = mean(x), y = mean(y))
tibble(x = base::mean(x), y = base::mean(y))

---

Expand All @@ -17,15 +17,15 @@
Code
test_duckplyr_expand_across(c("x", "y"), across(c(x_mean = x, y_mean = y), mean))
Output
tibble(x_mean = mean(x), y_mean = mean(y))
tibble(x_mean = base::mean(x), y_mean = base::mean(y))

---

Code
test_duckplyr_expand_across(c("x", "y"), across(c(x_mean = x, y_mean = y), mean,
.names = "{.col}_{.fn}"))
Output
tibble(x_mean_1 = mean(x), y_mean_1 = mean(y))
tibble(x_mean_1 = base::mean(x), y_mean_1 = base::mean(y))

---

Expand Down Expand Up @@ -56,3 +56,32 @@
Output
tibble(x = x * x, y = y * y)

---

Code
test_duckplyr_expand_across(c("x", "y"), across(x:y, base::mean))
Output
tibble(x = base::mean(x), y = base::mean(y))

---

Code
test_duckplyr_expand_across(c("x", "y"), across(x:y, list(mean)))
Output
tibble(x_1 = base::mean(x), y_1 = base::mean(y))

---

Code
test_duckplyr_expand_across(c("x", "y"), across(x:y, list(mean = mean)))
Output
tibble(x_mean = base::mean(x), y_mean = base::mean(y))

---

Code
test_duckplyr_expand_across(c("x", "y"), across(x:y, list(mean = mean, median = median)))
Output
tibble(x_mean = base::mean(x), x_median = stats::median(x), y_mean = base::mean(y),
y_median = stats::median(y))

40 changes: 28 additions & 12 deletions tests/testthat/test-duckplyr-across.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ test_that("duckplyr_expand_across() successful", {
across(-a, function(x) x * x)
)
})

expect_snapshot({
test_duckplyr_expand_across(
c("x", "y"),
across(x:y, base::mean)
)
})

expect_snapshot({
test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean))
)
})

expect_snapshot({
test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean = mean))
)
})

expect_snapshot({
test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean = mean, median = median))
)
})
})

test_that("duckplyr_expand_across() failing", {
Expand All @@ -65,16 +93,4 @@ test_that("duckplyr_expand_across() failing", {
c("x", "y"),
across(x:y, mean, na.rm = TRUE)
))
expect_null(test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean))
))
expect_null(test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean = mean))
))
expect_null(test_duckplyr_expand_across(
c("x", "y"),
across(x:y, list(mean = mean, median = median))
))
})

0 comments on commit 8e48b7c

Please sign in to comment.