Skip to content

Commit

Permalink
Merge pull request #120 from bschneidr/filtering-joins
Browse files Browse the repository at this point in the history
Add filtering joins, with documentation and tests.
  • Loading branch information
gergness authored May 23, 2021
2 parents 7c2c77c + f240676 commit 60cff29
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 0 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(anti_join,tbl_svy)
S3method(as.character,survey_vars)
S3method(as.data.frame,tbl_svy)
S3method(as_survey,data.frame)
Expand Down Expand Up @@ -58,6 +59,7 @@ S3method(rename_,tbl_svy)
S3method(rename_with,tbl_svy)
S3method(select,tbl_svy)
S3method(select_,tbl_svy)
S3method(semi_join,tbl_svy)
S3method(summarise,grouped_svy)
S3method(summarise,tbl_svy)
S3method(summarise_,grouped_svy)
Expand All @@ -70,6 +72,7 @@ S3method(ungroup,tbl_svy)
export("%>%")
export(across)
export(all_vars)
export(anti_join)
export(any_vars)
export(as_survey)
export(as_survey_)
Expand Down Expand Up @@ -149,6 +152,7 @@ export(select_)
export(select_all)
export(select_at)
export(select_if)
export(semi_join)
export(set_survey_vars)
export(summarise)
export(summarise_)
Expand Down Expand Up @@ -186,6 +190,7 @@ export(vars)
import(rlang)
importFrom(dplyr,across)
importFrom(dplyr,all_vars)
importFrom(dplyr,anti_join)
importFrom(dplyr,any_vars)
importFrom(dplyr,c_across)
importFrom(dplyr,collect)
Expand Down Expand Up @@ -241,6 +246,7 @@ importFrom(dplyr,select_)
importFrom(dplyr,select_all)
importFrom(dplyr,select_at)
importFrom(dplyr,select_if)
importFrom(dplyr,semi_join)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
Expand Down
93 changes: 93 additions & 0 deletions R/join.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#' @export
semi_join.tbl_svy <- function(
x,
y,
by = NULL,
copy = FALSE,
...,
na_matches = c("na", "never")
) {

if (inherits(y, "tbl_svy")) {
y <- y$variables
}

x <- mutate(x, `___row_number` = dplyr::row_number())

filtered_vars <- semi_join(x = x$variables,
y = y,
by = by,
copy = copy,
na_matches = na_matches,
...)

x <- filter(x, .data$`___row_number` %in% filtered_vars[['___row_number']])
if ("___row_number" %in% tbl_vars(x)) {
x <- select(x, -`___row_number`)
}

x

}

#' @export
anti_join.tbl_svy <- function(
x,
y,
by = NULL,
copy = FALSE,
...,
na_matches = c("na", "never")
) {

if (inherits(y, "tbl_svy")) {
y <- y$variables
}

x <- mutate(x, `___row_number` = dplyr::row_number())

filtered_vars <- anti_join(x = x$variables,
y = y,
by = by,
copy = copy,
na_matches = na_matches,
...)

x <- filter(x, .data$`___row_number` %in% filtered_vars[['___row_number']])
if ("___row_number" %in% tbl_vars(x)) {
x <- select(x, -`___row_number`)
}

x

}

# Import + export generics from dplyr and tidyr
#' Filtering joins from dplyr
#'
#' These are data manipulation functions designed to work on a \code{tbl_svy} object
#' and another data frame or \code{tbl_svy} object.
#'
#' \code{semi_join} and \code{anti_join} filter certain observations from a \code{tbl_svy}
#' depending on the presence or absence of matches in another table.
#' See \code{\link[dplyr]{filter-joins}} for more details.
#'
#' Mutating joins (\code{full_join}, \code{left_join}, etc.) are not implemented
#' for any \code{tbl_svy} objects. These data manipulations
#' may require modifications to the survey variable specifications and so
#' cannot be done automatically. Instead, use dplyr to perform them while the
#' data is still stored in data.frames.
#' @name dplyr_filter_joins
NULL

#' @name semi_join
#' @export
#' @importFrom dplyr semi_join
#' @rdname dplyr_filter_joins
NULL

#' @name anti_join
#' @export
#' @importFrom dplyr anti_join
#' @rdname dplyr_filter_joins
NULL
22 changes: 22 additions & 0 deletions man/dplyr_filter_joins.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

229 changes: 229 additions & 0 deletions tests/testthat/test_joins.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
context("filtering joins (semi_join and anti_join) work")

suppressPackageStartupMessages({
library(survey)
library(srvyr)
library(dplyr)
})

source("utilities.R")

# Set up example data ----

data(api)

##_ Create simple stratified survey design object ----
stratified_design <- apistrat %>%
as_survey_design(strata = stype, weights = pw)

##_ Create clustered survey design object ----
cluster_design <- as_survey_design(
.data = apiclus1,
id = dnum,
weights = pw,
fpc = fpc
)

##_ Create survey design object with calibration weights ----
##_ NOTE: The survey package uses special behavior when subsetting such survey designs.
##_ Rows are never removed, the weights are simply set effectively to zero (technically, Inf)

### Add raking weights for school type
pop.types <- data.frame(stype=c("E","H","M"), Freq=c(4421,755,1018))
pop.schwide <- data.frame(sch.wide=c("No","Yes"), Freq=c(1072,5122))

raked_design <- rake(
cluster_design,
sample.margins = list(~stype,~sch.wide),
population.margins = list(pop.types, pop.schwide)
)

# semi_join ----

test_that(
"semi_join works with `by = NULL`", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
semi_join(y = filter(apistrat, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
semi_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
semi_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

test_that(
"semi_join works with supplied `by` argument", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
semi_join(y = filter(apistrat, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
semi_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
semi_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype == "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

# anti_join ----

test_that(
"anti_join works with `by = NULL`", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
anti_join(y = filter(apistrat, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
anti_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
anti_join(y = filter(apiclus1, stype == "E")) %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

test_that(
"anti_join works with supplied `by` argument", {
# Stratified design
expect_equal(
## Calculate statistic, after using a filtering join
object = stratified_design %>%
anti_join(y = filter(apistrat, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = stratified_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Cluster design
expect_equal(
## Calculate statistic, after using a filtering join
object = cluster_design %>%
anti_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = cluster_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)

# Calibration weighted design
expect_equal(
## Calculate statistic, after using a filtering join
object = raked_design %>%
anti_join(y = filter(apiclus1, stype == "E"),
by = "stype") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat"),
## Calculate statistic after manually filtering
expected = raked_design %>%
filter(stype != "E") %>%
summarize(stat = survey_mean(pcttest)) %>%
pull("stat")
)
})

0 comments on commit 60cff29

Please sign in to comment.