Skip to content

Commit

Permalink
DRYify params
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Aug 31, 2024
1 parent 05b0ed6 commit fa6acaa
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 91 deletions.
132 changes: 47 additions & 85 deletions R/parameters.R
Original file line number Diff line number Diff line change
@@ -1,66 +1,9 @@
#' Read right-truncation PMF into memory
#'
read_right_truncation_pmf <- function(path,
disease = c(
"COVID-19",
"Influenza",
"test"
),
as_of_date,
state) {
rlang::arg_match(disease)
check_path(path)

# The comparison fails if the dates are not in string format, reformat if
# needed
if (inherits(as_of_date, "Date")) {
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = 'right_truncation'
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
AND geo_value = ?
;
",
params = list(
path,
disease,
as_of_date,
as_of_date,
state
)
),
error = function(cnd) {
cli::cli_abort(c(
"Failure loading {.arg right_truncation} from {.path {path}}",
"Using {.val {disease}} and {.val {as_of_date}}",
"Original error: {cnd}"
))
}
)
DBI::dbDisconnect(con)

pmf <- check_returned_pmf(pmf_df)
cli::cli_alert_success("{.arg right_truncation} loaded")

return(pmf)
}

#' Read generation and delay interval PMFs into memory
#' Read parameter PMFs into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
#' format with column names:
#' * parameter
#' * geo_value
#' * disease
#' * start_date
#' * end_date
Expand All @@ -69,18 +12,25 @@ read_right_truncation_pmf <- function(path,
#' @param path A path to a local file
#' @param disease One of "COVID-19" or "Influenza"
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation interval" or "delay"
#' @param parameter One of "generation interval", "delay", or "right-truncation
#'
#' @return A PMF vector
#' @export
read_interval_pmf <- function(path,
disease = c("COVID-19", "Influenza", "test"),
as_of_date,
parameter = c("generation_interval", "delay")) {
parameter = c(
"generation_interval",
"delay",
"right_truncation"
),
state = NA) {
###################
# Validate input
rlang::arg_match(parameter)
rlang::arg_match(disease)

cli::cli_alert_info("Reading {.arg {parameter}} from {.path {path}}")
cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist")
}
Expand All @@ -91,38 +41,58 @@ read_interval_pmf <- function(path,
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = "
################
# Prepare query

query <- "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
;
",
params = list(
path,
parameter,
disease,
as_of_date,
as_of_date
)
"
parameters <- list(
path,
parameter,
disease,
as_of_date,
as_of_date
)

# Handle state separately because can't use `=` for NULL comparison and
# DBI::dbBind() can't parameterize a query after IS
if (rlang::is_na(state) || rlang::is_null(state)) {
query <- paste(query, "AND geo_value IS NULL;")
} else {
query <- paste(query, "AND geo_value = ?")
parameters <- c(parameters, list(state))
}

################
# Execute query

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = query,
params = parameters
),
error = function(cnd) {
cli::cli_abort(c(
"Failure loading {.arg {parameter}} from {.path {path}}",
"Using {.val {disease}} and {.val {as_of_date}}",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {state}}",
"Original error: {cnd}"
))
}
)
DBI::dbDisconnect(con)


################
# Validate loaded PMF
pmf <- check_returned_pmf(pmf_df)
cli::cli_alert_success("{.arg {parameter}} loaded")

Expand Down Expand Up @@ -158,11 +128,3 @@ check_returned_pmf <- function(pmf_df) {

return(pmf)
}

check_path <- function(path) {
cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist")
}
invisible(NULL)
}
42 changes: 36 additions & 6 deletions tests/testthat/test-parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ test_that("Can read right-truncation on happy path", {
start_date = start_date,
end_date = NA
)
actual <- read_right_truncation_pmf(
actual <- read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
Expand All @@ -40,8 +41,9 @@ test_that("Can read right-truncation on happy path", {
start_date = start_date,
end_date = NA
)
actual <- read_right_truncation_pmf(
actual <- read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
Expand All @@ -50,6 +52,34 @@ test_that("Can read right-truncation on happy path", {
expect_equal(actual, expected)
})

test_that("Invalid PMF errors", {
expected <- c(0.8, -0.1)
path <- "test.parquet"
parameter <- "right_truncation"
start_date <- as.Date("2023-01-01")

# COVID-19
disease <- "COVID-19"
withr::with_tempdir({
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
expect_error(read_interval_pmf(
path = path,
parameter = parameter,
disease = disease,
as_of_date = start_date + 1,
state = "test"
))
})
})


test_that("Can read delay on happy path", {
Expand All @@ -64,7 +94,7 @@ test_that("Can read delay on happy path", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand All @@ -87,7 +117,7 @@ test_that("Can read delay on happy path", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand Down Expand Up @@ -117,7 +147,7 @@ test_that("Not a PMF errors", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand Down Expand Up @@ -244,7 +274,7 @@ test_that("Invalid query throws wrapped error", {
write_sample_parameters_file(
value = expected,
path = path,
state = "test",
state = NA,
disease = disease,
parameter = parameter,
param = parameter,
Expand Down

0 comments on commit fa6acaa

Please sign in to comment.