Skip to content

Commit

Permalink
Add test coverage of GI/delay reader
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Aug 30, 2024
1 parent 31af955 commit b637fc9
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 5 deletions.
35 changes: 30 additions & 5 deletions R/parameters.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
#' Read generation and delay interval PMFs into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
#' format with column names:
#' * parameter
#' * disease
#' * start_date
#' * end_date
#' * value
#'
#' @param path A path to a local file
#' @param disease One of "COVID-19" or "Influenza"
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation interval" or "delay"
#'
#' @return A PMF vector
#' @export
read_interval_pmf <- function(path,
disease,
disease = c("COVID-19", "Influenza", "test"),
as_of_date,
parameter = c("generation_interval", "delay")) {
rlang::arg_match(parameter)
rlang::arg_match(disease)

cli::cli_alert_info("Reading {.arg {parameter}} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist")
}

# The comparison fails if the dates are not in string format, reformat if
# needed
if (inherits(as_of_date, "Date")) {
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
Expand All @@ -20,8 +44,8 @@ read_interval_pmf <- function(path,
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? :: DATE OR end_date IS NULL)
LIMIT 5;
AND (end_date > ? OR end_date IS NULL)
;
",
params = list(
path,
Expand Down Expand Up @@ -51,8 +75,9 @@ read_interval_pmf <- function(path,

pmf <- pmf_df[["value"]][[1]]

if (length(pmf) < 1) {
cli::cli_abort(c("Invalid {.arg {parameter}} returned.",
if ((length(pmf) < 1) || !rlang::is_bare_numeric(pmf)) {
cli::cli_abort(c(
"Invalid {.arg {parameter}} returned.",
"x" = "Expected a PMF",
"i" = "Loaded object: {pmf}"
))
Expand Down
31 changes: 31 additions & 0 deletions tests/testthat/helper-write_parameter_file.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
write_sample_parameters_file <- function(value,
path,
state,
param,
disease,
parameter,
start_date,
end_date) {
df <- data.frame(
start_date = as.Date(start_date),
state = state,
disease = disease,
parameter = parameter,
end_date = end_date,
disease = disease,
value = I(list(value))
)

con <- DBI::dbConnect(duckdb::duckdb())

duckdb::duckdb_register(con, "test_table", df)
# This is bad practice but `dbBind()` doesn't allow us to parameterize COPY
# ... TO. The danger of doing it this way seems quite low risk because it's
# an ephemeral from a temporary in-memory DB. There's no actual database to
# guard against a SQL injection attack.
query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'")
DBI::dbExecute(con, query)
DBI::dbDisconnect(con)

invisible(path)
}
208 changes: 208 additions & 0 deletions tests/testthat/test-parameters.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
test_that("Can read on happy path", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")

# COVID-19
disease <- "COVID-19"
withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
actual <- read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date + 1,
parameter = parameter
)
})
expect_equal(actual, expected)


# Influenza
disease <- "Influenza"
withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
actual <- read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date + 1,
parameter = parameter
)
})
expect_equal(actual, expected)
})


test_that("Not a PMF errors", {
expected <- "hello"
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")

# COVID-19
disease <- "COVID-19"
withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
expect_error(read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date + 1,
parameter = parameter
))
})
})

test_that("Invalid disease errors", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")
disease <- "not_a_valid_disease"

withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)

expect_error(read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date + 1,
parameter = parameter
))
})
})

test_that("Invalid parameter errors", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "not_a_valid_parameter"
start_date <- as.Date("2023-01-01")
disease <- "COVID-19"

withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)

expect_error(read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date + 1,
parameter = parameter
))
})
})

test_that("Return isn't exactly one errors", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")
disease <- "COVID-19"

withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)

# Date too early
expect_error(read_interval_pmf(
path = path,
disease = disease,
as_of_date = start_date - 1,
parameter = parameter
))
})
})

test_that("No file exists errors", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")
disease <- "COVID-19"

expect_error(read_interval_pmf(
path = "not_a_real_file",
disease = disease,
as_of_date = start_date - 1,
parameter = parameter
))
})

test_that("Invalid query throws wrapped error", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
parameter <- "delay"
start_date <- as.Date("2023-01-01")
disease <- "COVID-19"

withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)

expect_error(read_interval_pmf(
path = path,
disease = disease,
as_of_date = "abc123",
parameter = parameter
))
})
})

0 comments on commit b637fc9

Please sign in to comment.