Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filtering joins, with documentation and tests. #120

Merged
merged 3 commits into from
May 23, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(anti_join,tbl_svy)
S3method(as.character,survey_vars)
S3method(as.data.frame,tbl_svy)
S3method(as_survey,data.frame)
Expand Down Expand Up @@ -58,6 +59,7 @@ S3method(rename_,tbl_svy)
S3method(rename_with,tbl_svy)
S3method(select,tbl_svy)
S3method(select_,tbl_svy)
S3method(semi_join,tbl_svy)
S3method(summarise,grouped_svy)
S3method(summarise,tbl_svy)
S3method(summarise_,grouped_svy)
Expand All @@ -70,6 +72,7 @@ S3method(ungroup,tbl_svy)
export("%>%")
export(across)
export(all_vars)
export(anti_join)
export(any_vars)
export(as_survey)
export(as_survey_)
Expand Down Expand Up @@ -149,6 +152,7 @@ export(select_)
export(select_all)
export(select_at)
export(select_if)
export(semi_join)
export(set_survey_vars)
export(summarise)
export(summarise_)
Expand Down Expand Up @@ -186,6 +190,7 @@ export(vars)
import(rlang)
importFrom(dplyr,across)
importFrom(dplyr,all_vars)
importFrom(dplyr,anti_join)
importFrom(dplyr,any_vars)
importFrom(dplyr,c_across)
importFrom(dplyr,collect)
Expand Down Expand Up @@ -241,6 +246,7 @@ importFrom(dplyr,select_)
importFrom(dplyr,select_all)
importFrom(dplyr,select_at)
importFrom(dplyr,select_if)
importFrom(dplyr,semi_join)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
Expand Down
91 changes: 91 additions & 0 deletions R/join.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#' @export
semi_join.tbl_svy <- function(
x,
y,
by = NULL,
copy = FALSE,
...,
na_matches = c("na", "never")
) {

if (inherits(y, "tbl_svy")) {
y <- y$variables
}

x <- mutate(x, `___row_number` = dplyr::row_number())

filtered_vars <- semi_join(x = x$variables,
y = y,
by = by,
copy = copy,
na_matches = na_matches,
...)

x <- mutate(x, `___retained` = `___row_number` %in% filtered_vars[['___row_number']])
Copy link
Owner

@gergness gergness May 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you skip the ___retained variable and just filter on the expression? (If not, I thnk you need to de-select the ___row_number variable, don't you?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's a good idea. I didn't de-select ___row_number because it was actually removed internally in the call to filter(). But just to be careful, in f240676 added a conditional select to remove that column if it still exists after using filter().

x <- filter(x, .data$`___retained`)
x <- select(x, -.data$`___retained`)

x

}

#' @export
anti_join.tbl_svy <- function(
x,
y,
by = NULL,
copy = FALSE,
...,
na_matches = c("na", "never")
) {

if (inherits(y, "tbl_svy")) {
y <- y$variables
}

x <- mutate(x, `___row_number` = dplyr::row_number())

filtered_vars <- anti_join(x = x$variables,
y = y,
by = by,
copy = copy,
na_matches = na_matches,
...)

x <- mutate(x, `___retained` = `___row_number` %in% filtered_vars[['___row_number']])
x <- filter(x, `___retained`)
x <- select(x, -`___retained`)

x

}

# Import + export generics from dplyr and tidyr
#' Filtering joins from dplyr
#'
#' These are data manipulation functions designed to work on a \code{tbl_svy} object
#' and another data frame or \code{tbl_svy} object.
#'
#' \code{semi_join} and \code{anti_join} filter certain observations from a \code{tbl_svy}
#' depending on the presence or absence of matches in another table.
#' See \code{\link[dplyr]{filter-joins}} for more details.
#'
#' Mutating joins (\code{full_join}, \code{left_join}, etc.) are not implemented
#' for any \code{tbl_svy} objects. These data manipulations
#' may require modifications to the survey variable specifications and so
#' cannot be done automatically. Instead, use dplyr to perform them while the
#' data is still stored in data.frames.
#' @name dplyr_filter_joins
NULL

#' @name semi_join
#' @export
#' @importFrom dplyr semi_join
#' @rdname dplyr_filter_joins
NULL

#' @name anti_join
#' @export
#' @importFrom dplyr anti_join
#' @rdname dplyr_filter_joins
NULL
22 changes: 22 additions & 0 deletions man/dplyr_filter_joins.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

229 changes: 229 additions & 0 deletions tests/testthat/test_joins.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
context("filtering joins (semi_join and anti_join) work")

suppressPackageStartupMessages({
library(survey)
library(srvyr)
library(dplyr)
})

source("utilities.R")

# Set up example data ----

data(api)

##_ Create simple stratified survey design object ----
stratified_design <- apistrat %>%
as_survey_design(strata = stype, weights = pw)

##_ Create clustered survey design object ----
cluster_design <- as_survey_design(
.data = apiclus1,
id = dnum,
weights = pw,
fpc = fpc
)

##_ Create survey design object with calibration weights ----
##_ NOTE: The survey package uses special behavior when subsetting such survey designs.
##_ Rows are never removed, the weights are simply set effectively to zero (technically, Inf)

### Add raking weights for school type
pop.types <- data.frame(stype=c("E","H","M"), Freq=c(4421,755,1018))
pop.schwide <- data.frame(sch.wide=c("No","Yes"), Freq=c(1072,5122))

raked_design <- rake(
cluster_design,
sample.margins = list(~stype,~sch.wide),
population.margins = list(pop.types, pop.schwide)
)

# semi_join ----

test_that(
"semi_join works with `by = NULL`", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
semi_join(y = filter(apistrat, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
semi_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
semi_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

test_that(
"semi_join works with supplied `by` argument", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
semi_join(y = filter(apistrat, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
semi_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
semi_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

# anti_join ----

test_that(
"anti_join works with `by = NULL`", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
anti_join(y = filter(apistrat, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
anti_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
anti_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

test_that(
"anti_join works with supplied `by` argument", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
anti_join(y = filter(apistrat, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
anti_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
anti_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})