diff --git a/R/parameters.R b/R/parameters.R index 4804420c..bfef758f 100644 --- a/R/parameters.R +++ b/R/parameters.R @@ -1,66 +1,9 @@ -#' Read right-truncation PMF into memory -#' -read_right_truncation_pmf <- function(path, - disease = c( - "COVID-19", - "Influenza", - "test" - ), - as_of_date, - state) { - rlang::arg_match(disease) - check_path(path) - - # The comparison fails if the dates are not in string format, reformat if - # needed - if (inherits(as_of_date, "Date")) { - as_of_date <- format(as_of_date, "%Y-%m-%d") - } - - con <- DBI::dbConnect(duckdb::duckdb()) - pmf_df <- rlang::try_fetch( - DBI::dbGetQuery( - conn = con, - statement = " - SELECT value - FROM read_parquet(?) - WHERE 1=1 - AND parameter = 'right_truncation' - AND disease = ? - AND start_date < ? :: DATE - AND (end_date > ? OR end_date IS NULL) - AND geo_value = ? - ; - ", - params = list( - path, - disease, - as_of_date, - as_of_date, - state - ) - ), - error = function(cnd) { - cli::cli_abort(c( - "Failure loading {.arg right_truncation} from {.path {path}}", - "Using {.val {disease}} and {.val {as_of_date}}", - "Original error: {cnd}" - )) - } - ) - DBI::dbDisconnect(con) - - pmf <- check_returned_pmf(pmf_df) - cli::cli_alert_success("{.arg right_truncation} loaded") - - return(pmf) -} - -#' Read generation and delay interval PMFs into memory +#' Read parameter PMFs into memory #' #' Using DuckDB from a parquet file. The function expects the file to be in SCD2 #' format with column names: #' * parameter +#' * geo_value #' * disease #' * start_date #' * end_date @@ -69,21 +12,23 @@ read_right_truncation_pmf <- function(path, #' @param path A path to a local file #' @param disease One of "COVID-19" or "Influenza" #' @param as_of_date The parameters "as of" the date of the model run -#' @param parameter One of "generation interval" or "delay" +#' @param parameter One of "generation interval", "delay", or "right-truncation #' #' @return A PMF vector #' @export read_interval_pmf <- function(path, disease = c("COVID-19", "Influenza", "test"), as_of_date, - parameter = c("generation_interval", "delay")) { + parameter = c( + "generation_interval", + "delay", + "right_truncation" + ), + state = NA) { rlang::arg_match(parameter) rlang::arg_match(disease) - cli::cli_alert_info("Reading {.arg {parameter}} from {.path {path}}") - if (!file.exists(path)) { - cli::cli_abort("File {.path {path}} does not exist") - } + check_path(path) # The comparison fails if the dates are not in string format, reformat if # needed @@ -91,11 +36,9 @@ read_interval_pmf <- function(path, as_of_date <- format(as_of_date, "%Y-%m-%d") } - con <- DBI::dbConnect(duckdb::duckdb()) - pmf_df <- rlang::try_fetch( - DBI::dbGetQuery( - conn = con, - statement = " + # Can't use `=` for NULL comparison and dbBind can't insert a ? after an is + if (rlang::is_na(state) || rlang::is_null(state)) { + query <- " SELECT value FROM read_parquet(?) WHERE 1=1 @@ -103,20 +46,50 @@ read_interval_pmf <- function(path, AND disease = ? AND start_date < ? :: DATE AND (end_date > ? OR end_date IS NULL) - ; - ", - params = list( - path, - parameter, - disease, - as_of_date, - as_of_date - ) + AND geo_value IS NULL + ; + " + parameters <- list( + path, + parameter, + disease, + as_of_date, + as_of_date + ) + } else { + query <- " + SELECT value + FROM read_parquet(?) + WHERE 1=1 + AND parameter = ? + AND disease = ? + AND start_date < ? :: DATE + AND (end_date > ? OR end_date IS NULL) + AND geo_value = ? + ; + " + + parameters <- list( + path, + parameter, + disease, + as_of_date, + as_of_date, + state + ) + } + + con <- DBI::dbConnect(duckdb::duckdb()) + pmf_df <- rlang::try_fetch( + DBI::dbGetQuery( + conn = con, + statement = query, + params = parameters ), error = function(cnd) { cli::cli_abort(c( "Failure loading {.arg {parameter}} from {.path {path}}", - "Using {.val {disease}} and {.val {as_of_date}}", + "Using {.val {disease}}, {.val {as_of_date}}, and {.val {state}}", "Original error: {cnd}" )) } diff --git a/tests/testthat/test-parameters.R b/tests/testthat/test-parameters.R index 01a1677d..b7c77577 100644 --- a/tests/testthat/test-parameters.R +++ b/tests/testthat/test-parameters.R @@ -17,8 +17,9 @@ test_that("Can read right-truncation on happy path", { start_date = start_date, end_date = NA ) - actual <- read_right_truncation_pmf( + actual <- read_interval_pmf( path = path, + parameter = parameter, disease = disease, as_of_date = start_date + 1, state = "test" @@ -40,8 +41,9 @@ test_that("Can read right-truncation on happy path", { start_date = start_date, end_date = NA ) - actual <- read_right_truncation_pmf( + actual <- read_interval_pmf( path = path, + parameter = parameter, disease = disease, as_of_date = start_date + 1, state = "test" @@ -50,6 +52,34 @@ test_that("Can read right-truncation on happy path", { expect_equal(actual, expected) }) +test_that("Invalid PMF errors", { + expected <- c(0.8, -0.1) + path <- "test.parquet" + parameter <- "right_truncation" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + expect_error(read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + state = "test" + )) + }) +}) test_that("Can read delay on happy path", { @@ -64,7 +94,7 @@ test_that("Can read delay on happy path", { write_sample_parameters_file( value = expected, path = path, - state = "test", + state = NA, disease = disease, parameter = parameter, param = parameter, @@ -87,7 +117,7 @@ test_that("Can read delay on happy path", { write_sample_parameters_file( value = expected, path = path, - state = "test", + state = NA, disease = disease, parameter = parameter, param = parameter,