Skip to content

Commit

Permalink
Merge pull request #196 from tidymodels/loopify
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Aug 8, 2023
2 parents 6e26ecc + 8b5c45c commit 3665f6b
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 44 deletions.
9 changes: 5 additions & 4 deletions R/collapse_cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,13 @@ prep.step_collapse_cart <- function(x, training, info = NULL, ...) {

#' @export
bake.step_collapse_cart <- function(object, new_data, ...) {
nms <- names(object$results)
check_new_data(nms, object, new_data)
col_names <- names(object$results)
check_new_data(col_names, object, new_data)

for (i in seq_along(object$results)) {
new_data <- convert_keys(nms[i], object$results[[i]], new_data)
for (col_name in col_names) {
new_data <- convert_keys(col_name, object$results[[col_name]], new_data)
}

new_data
}

Expand Down
15 changes: 7 additions & 8 deletions R/collapse_stringdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,13 @@ collapse_stringdist_impl <- function(x, dist, method, options) {

#' @export
bake.step_collapse_stringdist <- function(object, new_data, ...) {
col_names <- object$columns
# for backward compat
check_new_data(names(col_names), object, new_data)

for (i in seq_along(col_names)) {
new_data[, col_names[i]] <- collapse_apply(
new_data[[col_names[i]]],
object$results[[i]]
col_names <- names(object$columns)
check_new_data(col_names, object, new_data)

for (col_name in col_names) {
new_data[[col_name]] <- collapse_apply(
new_data[[col_name]],
object$results[[col_name]]
)
}
as_tibble(new_data)
Expand Down
15 changes: 10 additions & 5 deletions R/embed.R
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,15 @@ map_tf_coef2 <- function(dat, mapping, prefix) {

#' @export
bake.step_embed <- function(object, new_data, ...) {
check_new_data(names(object$mapping), object, new_data)

for (col in names(object$mapping)) {
tmp <- map_tf_coef2(new_data[, col], object$mapping[[col]], prefix = col)
col_names <- names(object$mapping)
check_new_data(col_names, object, new_data)

for (col_name in col_names) {
tmp <- map_tf_coef2(
dat = new_data[, col_name], # map_tf_coef2() expects a tibble
mapping = object$mapping[[col_name]],
prefix = col_name
)

tmp <- check_name(tmp, new_data, object, names(tmp))

Expand All @@ -423,7 +428,7 @@ bake.step_embed <- function(object, new_data, ...) {

keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
new_data <- new_data[, !(names(new_data) %in% names(object$mapping))]
new_data <- new_data[, !(names(new_data) %in% col_names)]
}

new_data
Expand Down
11 changes: 6 additions & 5 deletions R/feature_hash.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,18 @@ make_hash_tbl <- function(ind, nms) {

#' @export
bake.step_feature_hash <- function(object, new_data, ...) {
check_new_data(names(object$columns), object, new_data)
col_names <- names(object$columns)
check_new_data(col_names, object, new_data)

# If no terms were selected
if (length(object$columns) == 0) {
if (length(col_names) == 0) {
return(new_data)
}

new_names <- paste0(object$columns, "_hash_")
new_names <- paste0(col_names, "_hash_")

new_cols <- purrr::map2_dfc(
new_data[, object$columns],
new_data[, col_names],
new_names, make_hash_vars,
num_hash =
object$num_hash
Expand All @@ -220,7 +221,7 @@ bake.step_feature_hash <- function(object, new_data, ...) {

keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
new_data <- new_data %>% dplyr::select(-one_of(!!!object$columns))
new_data <- new_data %>% dplyr::select(-one_of(!!!col_names))
}

new_data
Expand Down
10 changes: 7 additions & 3 deletions R/lencode_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,14 @@ stan_coefs <- function(x, y, options, verbose, wts = NULL, ...) {

#' @export
bake.step_lencode_bayes <- function(object, new_data, ...) {
check_new_data(names(object$mapping), object, new_data)
col_names <- names(object$mapping)
check_new_data(col_names, object, new_data)

for (col in names(object$mapping)) {
new_data[, col] <- map_glm_coef(new_data[, col], object$mapping[[col]])
for (col_name in col_names) {
new_data[[col_name]] <- map_glm_coef(
new_data[, col_name], # map_glm_coef() expects a tibble
object$mapping[[col_name]]
)
}

new_data
Expand Down
10 changes: 7 additions & 3 deletions R/lencode_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,14 @@ map_glm_coef <- function(dat, mapping) {

#' @export
bake.step_lencode_glm <- function(object, new_data, ...) {
check_new_data(names(object$mapping), object, new_data)
col_names <- names(object$mapping)
check_new_data(col_names, object, new_data)

for (col in names(object$mapping)) {
new_data[, col] <- map_glm_coef(new_data[, col], object$mapping[[col]])
for (col_name in col_names) {
new_data[[col_name]] <- map_glm_coef(
dat = new_data[, col_name], # map_glm_coef() expects a tibble
mapping = object$mapping[[col_name]]
)
}

new_data
Expand Down
10 changes: 7 additions & 3 deletions R/lencode_mixed.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,14 @@ map_lme_coef <- function(dat, mapping) {

#' @export
bake.step_lencode_mixed <- function(object, new_data, ...) {
check_new_data(names(object$mapping), object, new_data)
col_names <- names(object$mapping)
check_new_data(col_names, object, new_data)

for (col in names(object$mapping)) {
new_data[, col] <- map_lme_coef(new_data[, col], object$mapping[[col]])
for (col_name in col_names) {
new_data[[col_name]] <- map_lme_coef(
dat = new_data[, col_name], # map_lme_coef() expects a tibble
mapping = object$mapping[[col_name]]
)
}

new_data
Expand Down
9 changes: 4 additions & 5 deletions R/pca_truncated.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,14 @@ prep.step_pca_truncated <- function(x, training, info = NULL, ...) {

#' @export
bake.step_pca_truncated <- function(object, new_data, ...) {
if (is.null(object$columns)) {
object$columns <- stats::setNames(nm = rownames(object$res$rotation))
}
col_names <- names(object$columns) %||%
stats::setNames(nm = rownames(object$res$rotation))

if (length(object$columns) == 0 || all(is.na(object$res$rotation))) {
if (length(col_names) == 0 || all(is.na(object$res$rotation))) {
return(new_data)
}

check_new_data(object$columns, object, new_data)
check_new_data(col_names, object, new_data)

pca_vars <- rownames(object$res$rotation)
comps <- scale(new_data[, pca_vars], object$res$center, object$res$scale) %*%
Expand Down
13 changes: 8 additions & 5 deletions R/umap.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,10 @@ prep.step_umap <- function(x, training, info = NULL, ...) {

#' @export
bake.step_umap <- function(object, new_data, ...) {
check_new_data(names(object$object$xnames), object, new_data)
col_names <- names(object$object$xnames)
check_new_data(col_names, object, new_data)

if (length(object$object) == 0) {
if (length(col_names) == 0) {
return(new_data)
}

Expand All @@ -261,11 +262,13 @@ bake.step_umap <- function(object, new_data, ...) {
res <-
uwot::umap_transform(
model = object$object,
X = new_data[, object$object$xnames]
X = new_data[, col_names]
)
)

if (is.null(object$prefix)) object$prefix <- "UMAP"
if (is.null(object$prefix)) {
object$prefix <- "UMAP"
}

colnames(res) <- names0(object$num_comp, prefix = object$prefix)
res <- as_tibble(res)
Expand All @@ -275,7 +278,7 @@ bake.step_umap <- function(object, new_data, ...) {

keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
keep_cols <- !(colnames(new_data) %in% object$object$xnames)
keep_cols <- !(colnames(new_data) %in% col_names)
new_data <- new_data[, keep_cols, drop = FALSE]
}

Expand Down
6 changes: 3 additions & 3 deletions R/woe.R
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,9 @@ prep.step_woe <- function(x, training, info = NULL, ...) {
#' @export
bake.step_woe <- function(object, new_data, ...) {
dict <- object$dictionary
woe_vars <- unique(dict[["variable"]])
col_names <- unique(dict[["variable"]])

check_new_data(woe_vars, object, new_data)
check_new_data(col_names, object, new_data)

if (nrow(object$dictionary) == 0) {
return(new_data)
Expand All @@ -475,7 +475,7 @@ bake.step_woe <- function(object, new_data, ...) {

keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
new_data <- new_data[, !(colnames(new_data) %in% woe_vars), drop = FALSE]
new_data <- new_data[, !(colnames(new_data) %in% col_names), drop = FALSE]
}

new_data
Expand Down

0 comments on commit 3665f6b

Please sign in to comment.