Skip to content

Commit

Permalink
DRYify params
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Aug 31, 2024
1 parent 05b0ed6 commit b6045f6
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 84 deletions.
133 changes: 53 additions & 80 deletions R/parameters.R
Original file line number Diff line number Diff line change
@@ -1,66 +1,9 @@
#' Read right-truncation PMF into memory
#'
read_right_truncation_pmf <- function(path,
disease = c(
"COVID-19",
"Influenza",
"test"
),
as_of_date,
state) {
rlang::arg_match(disease)
check_path(path)

# The comparison fails if the dates are not in string format, reformat if
# needed
if (inherits(as_of_date, "Date")) {
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = 'right_truncation'
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
AND geo_value = ?
;
",
params = list(
path,
disease,
as_of_date,
as_of_date,
state
)
),
error = function(cnd) {
cli::cli_abort(c(
"Failure loading {.arg right_truncation} from {.path {path}}",
"Using {.val {disease}} and {.val {as_of_date}}",
"Original error: {cnd}"
))
}
)
DBI::dbDisconnect(con)

pmf <- check_returned_pmf(pmf_df)
cli::cli_alert_success("{.arg right_truncation} loaded")

return(pmf)
}

#' Read generation and delay interval PMFs into memory
#' Read parameter PMFs into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
#' format with column names:
#' * parameter
#' * geo_value
#' * disease
#' * start_date
#' * end_date
Expand All @@ -69,54 +12,84 @@ read_right_truncation_pmf <- function(path,
#' @param path A path to a local file
#' @param disease One of "COVID-19" or "Influenza"
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation interval" or "delay"
#' @param parameter One of "generation interval", "delay", or "right-truncation
#'
#' @return A PMF vector
#' @export
read_interval_pmf <- function(path,
disease = c("COVID-19", "Influenza", "test"),
as_of_date,
parameter = c("generation_interval", "delay")) {
parameter = c(
"generation_interval",
"delay",
"right_truncation"
),
state = NA) {
rlang::arg_match(parameter)
rlang::arg_match(disease)

cli::cli_alert_info("Reading {.arg {parameter}} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist")
}
check_path(path)

# The comparison fails if the dates are not in string format, reformat if
# needed
if (inherits(as_of_date, "Date")) {
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = "
# Can't use `=` for NULL comparison and dbBind can't insert a ? after an is
if (rlang::is_na(state) || rlang::is_null(state)) {
query <- "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
;
",
params = list(
path,
parameter,
disease,
as_of_date,
as_of_date
)
AND geo_value IS NULL
;
"
parameters <- list(
path,
parameter,
disease,
as_of_date,
as_of_date
)
} else {
query <- "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
AND geo_value = ?
;
"

parameters <- list(
path,
parameter,
disease,
as_of_date,
as_of_date,
state
)
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = query,
params = parameters
),
error = function(cnd) {
cli::cli_abort(c(
"Failure loading {.arg {parameter}} from {.path {path}}",
"Using {.val {disease}} and {.val {as_of_date}}",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {state}}",
"Original error: {cnd}"

Check warning on line 93 in R/parameters.R

View check run for this annotation

Codecov / codecov/patch

R/parameters.R#L90-L93

Added lines #L90 - L93 were not covered by tests
))
}
Expand Down
38 changes: 34 additions & 4 deletions tests/testthat/test-parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ test_that("Can read right-truncation on happy path", {
start_date = start_date,
end_date = NA
)
actual <- read_right_truncation_pmf(
actual <- read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
Expand All @@ -40,8 +41,9 @@ test_that("Can read right-truncation on happy path", {
start_date = start_date,
end_date = NA
)
actual <- read_right_truncation_pmf(
actual <- read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
Expand All @@ -50,6 +52,34 @@ test_that("Can read right-truncation on happy path", {
expect_equal(actual, expected)
})

test_that("Invalid PMF errors", {
expected <- c(0.8, -0.1)
path <- "test.parquet"
parameter <- "right_truncation"
start_date <- as.Date("2023-01-01")

# COVID-19
disease <- "COVID-19"
withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
expect_error(read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
))
})
})


test_that("Can read delay on happy path", {
Expand All @@ -64,7 +94,7 @@ test_that("Can read delay on happy path", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand All @@ -87,7 +117,7 @@ test_that("Can read delay on happy path", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand Down

0 comments on commit b6045f6

Please sign in to comment.