From 9152bc0ee629562a1e14be453459c05cd549aac4 Mon Sep 17 00:00:00 2001 From: Zachary Susswein <46581799+zsusswein@users.noreply.github.com> Date: Wed, 4 Sep 2024 09:55:24 -0400 Subject: [PATCH] Read in parameter PMFs (#16) * Parameter PMF reader For generation interval, delays, and right-truncation. It assumes that: - We want a PMF - The PMF is coming from a file with schema specified as in cdcent/cfa-parameter-estimates#9 - The parameter names and disease names are following a specified schema - The file is in parquet format - The PMFs are actually proper PMFs It relaxes the assumptions that: - The PMFs are coming from the same file - The PMFs must be present (can skip delays and right-truncation but not GI) - The files are in Azure or have a specific name/path Unit tests cover successfully reading all the parameters individually as well as in combination. They also check for failure in the expected places for desired failure modes. Switching over to this schema from manual CSVs now is a bit of a choice. I took three cracks at this PR before landing on doing it this way. I think it provides a couple of important benefits to make the switch now and that comes through in designing the code here: 1. Our existing CSV-based approach produces 3 distinct files with close-ish schemas, but not an exact match. It requires distinct reader functions and substantially more code than reading from a file with a unified schema like I do here. 2. I think this approach helps avoid issues like this week's production mishap. It allows us to mix and match files as needed, and we can point to a single drop-in file for testing if desired. We don't need to fiddle with the production environment. 3. I think we're going to need to make a switch on the parameter approach sooner or later. I wanted to bite the bullet and get it over with all at once. It's convenient that it helps make the code simpler, but I think making changes swiftly rather than dragging things out provides its own benefit too. * Bump NEWS * Document new functions * Update NEWS.md * Drop roxygenize hooking -- it's fully broken * Clarify `as_of` date * Clarify state-varying params * Use classed errors For additional specificity linking expected errors to tests. I use the regexp option instead of classed errors for `arg_match()` because I had trouble catching those even when matching the class exactly. * Document * Apply suggestions from code review Co-authored-by: Katie Gostic (she/her) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add clarifying info from code review * state -> group But I left state in the documentation for clarity because we don't currently do anything at the sub-state level * Document * Update tests & docs with changes from review * `read_parameters()` -> `read_disease_parameters()` --------- Co-authored-by: zsusswein Co-authored-by: Katie Gostic (she/her) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 10 +- DESCRIPTION | 2 + NAMESPACE | 2 + NEWS.md | 1 + R/parameters.R | 238 +++++++++++ man/add_two_numbers.Rd | 3 +- man/read_disease_parameters.Rd | 45 ++ man/read_interval_pmf.Rd | 53 +++ tests/testthat/helper-write_parameter_file.R | 30 ++ tests/testthat/test-parameters.R | 424 +++++++++++++++++++ 10 files changed, 801 insertions(+), 7 deletions(-) create mode 100644 R/parameters.R create mode 100644 man/read_disease_parameters.Rd create mode 100644 man/read_interval_pmf.Rd create mode 100644 tests/testthat/helper-write_parameter_file.R create mode 100644 tests/testthat/test-parameters.R diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07413981..91af2392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,6 @@ repos: - id: style-files args: [--style_pkg=styler, --style_fun=tidyverse_style, --cache-root=styler-perm] - - id: roxygenize - id: use-tidy-description - id: lintr - id: readme-rmd-rendered @@ -49,7 +48,7 @@ repos: ##### # Python - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: # if you have ipython notebooks, consider using # `black-jupyter` hook instead @@ -62,13 +61,13 @@ repos: args: ['--profile', 'black', '--line-length', '79'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.2 + rev: v0.6.3 hooks: - id: ruff ##### # Java - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.13.0 + rev: v2.14.0 hooks: - id: pretty-format-java args: [--aosp,--autofix] @@ -83,7 +82,7 @@ repos: ##### # Secrets - repo: https://github.com/Yelp/detect-secrets - rev: v1.4.0 + rev: v1.5.0 hooks: - id: detect-secrets args: ['--baseline', '.secrets.baseline'] @@ -97,5 +96,4 @@ ci: autoupdate_branch: '' autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' autoupdate_schedule: weekly - skip: [roxygenize] submodules: false diff --git a/DESCRIPTION b/DESCRIPTION index 35c367e0..390d1e7d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,5 +26,7 @@ Imports: AzureRMR, AzureStor, cli, + DBI, + duckdb, rlang URL: https://cdcgov.github.io/cfa-epinow2-pipeline/ diff --git a/NAMESPACE b/NAMESPACE index b6920831..25d4e0a6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,3 +4,5 @@ export(add_two_numbers) export(download_from_azure_blob) export(fetch_blob_container) export(fetch_credential_from_env_var) +export(read_disease_parameters) +export(read_interval_pmf) diff --git a/NEWS.md b/NEWS.md index bfa2eb3b..efaf34db 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # CFAEpiNow2Pipeline (development version) +* Parameters read from local parquet file or files * Additional CI bug squashing * Bug fixed in the updated, faster pre-commit checks * Updated, faster pre-commit checks diff --git a/R/parameters.R b/R/parameters.R new file mode 100644 index 00000000..a6c5dbfa --- /dev/null +++ b/R/parameters.R @@ -0,0 +1,238 @@ +#' Read in disease process parameters from an external file or files +#' +#' @param generation_interval_path,delay_interval_path,right_truncation_path +#' Path to a local file with the parameter PMF. See [read_interval_pmf] for +#' details on the file schema. The parameters can be in the same file or a +#' different file. +#' @param disease One of `COVID-19` or `Influenza` +#' @param as_of_date Use the parameters that were used in production on this +#' date. Set for the current date for the most up-to-to date version of the +#' parameters and set to an earlier date to use parameters from an earlier +#' time period. +#' @param group Used only for parameters with a state-level estimate (i.e., only +#' right-truncation). The two-letter uppercase state abbreviation. +#' +#' @return A named list with three PMFs. The list elements are named +#' `generation_interval`, `delay_interval`, and `right_truncation`. If a path +#' to a local file is not provided (NA or NULL), the corresponding parameter +#' estimate will be NA in the returned list. +#' @details `generation_interval_path` is required because the generation +#' interval is a required parameter for $R_t$ estimation. +#' `delay_interval_path` and `right_truncation_path` are optional (but +#' strongly suggested). +#' @export +read_disease_parameters <- function( + generation_interval_path, + delay_interval_path, + right_truncation_path, + disease, + as_of_date, + group) { + generation_interval <- read_interval_pmf( + path = generation_interval_path, + disease = disease, + as_of_date = as_of_date, + parameter = "generation_interval" + ) + + if (path_is_specified(delay_interval_path)) { + delay_interval <- read_interval_pmf( + path = delay_interval_path, + disease = disease, + as_of_date = as_of_date, + parameter = "delay" + ) + } else { + cli::cli_alert_warning( + "No delay interval path specified. Using a delay of 0 days." + ) + delay_interval <- NA + } + + if (path_is_specified(right_truncation_path)) { + right_truncation <- read_interval_pmf( + path = right_truncation_path, + disease = disease, + as_of_date = as_of_date, + parameter = "right_truncation", + group = group + ) + } else { + cli::cli_alert_warning( + "No right truncation path specified. Not adjusting for right truncation." + ) + right_truncation <- NA + } + + parameters <- list( + generation_interval = generation_interval, + delay_interval = delay_interval, + right_truncation = right_truncation + ) + return(parameters) +} + +path_is_specified <- function(path) { + !rlang::is_null(path) && + !rlang::is_na(path) +} + +#' Read parameter PMF into memory +#' +#' Using DuckDB from a parquet file. The function expects the file to be in SCD2 +#' format with column names: +#' * parameter +#' * geo_value +#' * disease +#' * start_date +#' * end_date +#' * value +#' +#' start_date and end_date specify the date range for which the value was used. +#' end_date may be NULL (e.g. for the current value used in production). value +#' must contain a pmf vector whose values are all positive and sum to 1. all +#' other fields must be consistent with the specifications of the function +#' arguments describe below, which are used to query from the .parquet file. +#' +#' SCD2 format is shorthand for slowly changing dimension type 2. This format is +#' normalized to track change over time: +#' https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row +#' +#' @param path A path to a local file +#' @param disease One of "COVID-19" or "Influenza" +#' @param as_of_date The parameters "as of" the date of the model run +#' @param parameter One of "generation interval", "delay", or "right-truncation +#' @param group An optional parameter to subset the query to a parameter with a +#' particular two-letter state abbrevation. Right now, the only parameter with +#' state-specific estimates is `right-truncation`. +#' +#' @return A PMF vector +#' @export +read_interval_pmf <- function(path, + disease = c("COVID-19", "Influenza", "test"), + as_of_date, + parameter = c( + "generation_interval", + "delay", + "right_truncation" + ), + group = NA) { + ################### + # Validate input + rlang::arg_match(parameter) + rlang::arg_match(disease) + + as_of_date <- stringify_date(as_of_date) + cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}") + if (!file.exists(path)) { + cli::cli_abort("File {.path {path}} does not exist", + class = "file_not_found" + ) + } + + + ################ + # Prepare query + + query <- " + SELECT value + FROM read_parquet(?) + WHERE 1=1 + AND parameter = ? + AND disease = ? + AND start_date < ? :: DATE + AND (end_date > ? OR end_date IS NULL) + " + parameters <- list( + path, + parameter, + disease, + as_of_date, + as_of_date + ) + + # Handle state separately because can't use `=` for NULL comparison and + # DBI::dbBind() can't parameterize a query after IS + if (rlang::is_na(group) || rlang::is_null(group)) { + query <- paste(query, "AND geo_value IS NULL;") + } else { + query <- paste(query, "AND geo_value = ?") + parameters <- c(parameters, list(group)) + } + + ################ + # Execute query + + con <- DBI::dbConnect(duckdb::duckdb()) + pmf_df <- rlang::try_fetch( + DBI::dbGetQuery( + conn = con, + statement = query, + params = parameters + ), + error = function(cnd) { + cli::cli_abort( + c( + "Failure loading {.arg {parameter}} from {.path {path}}", + "Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}", + "Original error: {cnd}" + ), + class = "wrapped_error" + ) + } + ) + DBI::dbDisconnect(con) + + + ################ + # Validate loaded PMF + if (nrow(pmf_df) != 1) { + cli::cli_abort( + c( + "Failure loading {.arg {parameter}} from {.path {path}} ", + "Query did not return exactly one row", + "Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}", + "Query matched {.val {nrow(pmf_df)}} rows" + ), + class = "not_one_row_returned" + ) + } + + pmf <- pmf_df[["value"]][[1]] + + if ((length(pmf) < 1) || !rlang::is_bare_numeric(pmf)) { + cli::cli_abort( + c( + "Invalid {.arg {parameter}} returned.", + "x" = "Expected a PMF", + "i" = "Loaded object: {pmf_df}" + ), + class = "not_a_pmf" + ) + } + + if (any(pmf < 0) || any(pmf > 1) || abs(sum(pmf) - 1) > 1e-10) { + cli::cli_abort( + c( + "Returned numeric vector is not a valid PMF", + "Any below 0: {any(pmf < 0)}", + "Any above 1: {any(pmf > 1)}", + "Sum is within 1 with tol of 1e-10: {abs(sum(pmf) - 1) < 1e-10}", + "pmf: {.val {pmf}}" + ), + class = "invalid_pmf" + ) + } + + cli::cli_alert_success("{.arg {parameter}} loaded") + + return(pmf) +} + +#' DuckDB date comparison fails if the dates are not in string format +#' @noRd +stringify_date <- function(date) { + if (inherits(date, "Date")) { + format(date, "%Y-%m-%d") + } +} diff --git a/man/add_two_numbers.Rd b/man/add_two_numbers.Rd index ba7a1446..ffd18496 100644 --- a/man/add_two_numbers.Rd +++ b/man/add_two_numbers.Rd @@ -15,5 +15,6 @@ add_two_numbers(x, y) Their sum } \description{ -A temp function +Adding some meaningless text to test rendering via PR command. +Adding some more } diff --git a/man/read_disease_parameters.Rd b/man/read_disease_parameters.Rd new file mode 100644 index 00000000..7003932b --- /dev/null +++ b/man/read_disease_parameters.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parameters.R +\name{read_disease_parameters} +\alias{read_disease_parameters} +\title{Read in disease process parameters from an external file or files} +\usage{ +read_disease_parameters( + generation_interval_path, + delay_interval_path, + right_truncation_path, + disease, + as_of_date, + group +) +} +\arguments{ +\item{generation_interval_path, delay_interval_path, right_truncation_path}{Path to a local file with the parameter PMF. See \link{read_interval_pmf} for +details on the file schema. The parameters can be in the same file or a +different file.} + +\item{disease}{One of \code{COVID-19} or \code{Influenza}} + +\item{as_of_date}{Use the parameters that were used in production on this +date. Set for the current date for the most up-to-to date version of the +parameters and set to an earlier date to use parameters from an earlier +time period.} + +\item{group}{Used only for parameters with a state-level estimate (i.e., only +right-truncation). The two-letter uppercase state abbreviation.} +} +\value{ +A named list with three PMFs. The list elements are named +\code{generation_interval}, \code{delay_interval}, and \code{right_truncation}. If a path +to a local file is not provided (NA or NULL), the corresponding parameter +estimate will be NA in the returned list. +} +\description{ +Read in disease process parameters from an external file or files +} +\details{ +\code{generation_interval_path} is required because the generation +interval is a required parameter for $R_t$ estimation. +\code{delay_interval_path} and \code{right_truncation_path} are optional (but +strongly suggested). +} diff --git a/man/read_interval_pmf.Rd b/man/read_interval_pmf.Rd new file mode 100644 index 00000000..fa1897f6 --- /dev/null +++ b/man/read_interval_pmf.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parameters.R +\name{read_interval_pmf} +\alias{read_interval_pmf} +\title{Read parameter PMF into memory} +\usage{ +read_interval_pmf( + path, + disease = c("COVID-19", "Influenza", "test"), + as_of_date, + parameter = c("generation_interval", "delay", "right_truncation"), + group = NA +) +} +\arguments{ +\item{path}{A path to a local file} + +\item{disease}{One of "COVID-19" or "Influenza"} + +\item{as_of_date}{The parameters "as of" the date of the model run} + +\item{parameter}{One of "generation interval", "delay", or "right-truncation} + +\item{group}{An optional parameter to subset the query to a parameter with a +particular two-letter state abbrevation. Right now, the only parameter with +state-specific estimates is \code{right-truncation}.} +} +\value{ +A PMF vector +} +\description{ +Using DuckDB from a parquet file. The function expects the file to be in SCD2 +format with column names: +\itemize{ +\item parameter +\item geo_value +\item disease +\item start_date +\item end_date +\item value +} +} +\details{ +start_date and end_date specify the date range for which the value was used. +end_date may be NULL (e.g. for the current value used in production). value +must contain a pmf vector whose values are all positive and sum to 1. all +other fields must be consistent with the specifications of the function +arguments describe below, which are used to query from the .parquet file. + +SCD2 format is shorthand for slowly changing dimension type 2. This format is +normalized to track change over time: +https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row +} diff --git a/tests/testthat/helper-write_parameter_file.R b/tests/testthat/helper-write_parameter_file.R new file mode 100644 index 00000000..387ec9ca --- /dev/null +++ b/tests/testthat/helper-write_parameter_file.R @@ -0,0 +1,30 @@ +write_sample_parameters_file <- function(value, + path, + state, + param, + disease, + parameter, + start_date, + end_date) { + df <- data.frame( + start_date = as.Date(start_date), + geo_value = state, + disease = disease, + parameter = parameter, + end_date = end_date, + value = I(list(value)) + ) + + con <- DBI::dbConnect(duckdb::duckdb()) + + duckdb::duckdb_register(con, "test_table", df) + # This is bad practice but `dbBind()` doesn't allow us to parameterize COPY + # ... TO. The danger of doing it this way seems quite low risk because it's + # an ephemeral from a temporary in-memory DB. There's no actual database to + # guard against a SQL injection attack. + query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'") + DBI::dbExecute(con, query) + DBI::dbDisconnect(con) + + invisible(path) +} diff --git a/tests/testthat/test-parameters.R b/tests/testthat/test-parameters.R new file mode 100644 index 00000000..bf0ca0c6 --- /dev/null +++ b/tests/testthat/test-parameters.R @@ -0,0 +1,424 @@ +test_that("Can read all params on happy path", { + expected <- c(0.8, 0.2) + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + parameter = "generation_interval", + path = "generation_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "delay", + path = "delay_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "right_truncation", + path = "right_truncation.parquet", + disease = disease, + state = "test", + start_date = start_date, + end_date = NA + ) + + + actual <- read_disease_parameters( + generation_interval_path = "generation_interval.parquet", + delay_interval_path = "delay_interval.parquet", + right_truncation_path = "right_truncation.parquet", + disease = "COVID-19", + as_of_date = start_date + 1, + group = "test" + ) + }) + + + expect_equal( + actual, + list( + generation_interval = expected, + delay_interval = expected, + right_truncation = expected + ) + ) +}) + +test_that("Can skip params on happy path", { + expected <- c(0.8, 0.2) + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + parameter = "generation_interval", + path = "generation_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "delay", + path = "delay_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "right_truncation", + path = "right_truncation.parquet", + disease = disease, + state = "test", + start_date = start_date, + end_date = NA + ) + + + actual <- read_disease_parameters( + generation_interval_path = "generation_interval.parquet", + delay_interval_path = NULL, + right_truncation_path = NULL, + disease = "COVID-19", + as_of_date = start_date + 1, + group = "test" + ) + }) + + + expect_equal( + actual, + list( + generation_interval = expected, + delay_interval = NA, + right_truncation = NA + ) + ) +}) + +test_that("Can read right-truncation on happy path", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "right_truncation" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ) + }) + expect_equal(actual, expected) + + + # Influenza + disease <- "Influenza" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ) + }) + expect_equal(actual, expected) +}) + +test_that("Invalid PMF errors", { + expected <- c(0.8, -0.1) + path <- "test.parquet" + parameter <- "right_truncation" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + expect_error( + read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ), + class = "invalid_pmf" + ) + }) +}) + + +test_that("Can read delay on happy path", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ) + }) + expect_equal(actual, expected) + + + # Influenza + disease <- "Influenza" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ) + }) + expect_equal(actual, expected) +}) + + +test_that("Not a PMF errors", { + expected <- "hello" + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + class = "not_a_pmf" + ) + }) +}) + +test_that("Invalid disease errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "not_a_valid_disease" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + regexp = "`disease` must be one of" + ) + }) +}) + +test_that("Invalid parameter errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "not_a_valid_parameter" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + regexp = "`parameter` must be one of" + ) + }) +}) + +test_that("Return isn't exactly one errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + # Date too early + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date - 1, + parameter = parameter + ), + class = "not_one_row_returned" + ) + }) +}) + +test_that("No file exists errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + expect_error( + read_interval_pmf( + path = "not_a_real_file", + disease = disease, + as_of_date = start_date - 1, + parameter = parameter + ), + class = "file_not_found" + ) +}) + +test_that("Invalid query throws wrapped error", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = "abc123", + parameter = parameter + ), + class = "wrapped_error" + ) + }) +})