diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07413981..91af2392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,6 @@ repos: - id: style-files args: [--style_pkg=styler, --style_fun=tidyverse_style, --cache-root=styler-perm] - - id: roxygenize - id: use-tidy-description - id: lintr - id: readme-rmd-rendered @@ -49,7 +48,7 @@ repos: ##### # Python - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: # if you have ipython notebooks, consider using # `black-jupyter` hook instead @@ -62,13 +61,13 @@ repos: args: ['--profile', 'black', '--line-length', '79'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.2 + rev: v0.6.3 hooks: - id: ruff ##### # Java - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.13.0 + rev: v2.14.0 hooks: - id: pretty-format-java args: [--aosp,--autofix] @@ -83,7 +82,7 @@ repos: ##### # Secrets - repo: https://github.com/Yelp/detect-secrets - rev: v1.4.0 + rev: v1.5.0 hooks: - id: detect-secrets args: ['--baseline', '.secrets.baseline'] @@ -97,5 +96,4 @@ ci: autoupdate_branch: '' autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' autoupdate_schedule: weekly - skip: [roxygenize] submodules: false diff --git a/DESCRIPTION b/DESCRIPTION index 35c367e0..390d1e7d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,5 +26,7 @@ Imports: AzureRMR, AzureStor, cli, + DBI, + duckdb, rlang URL: https://cdcgov.github.io/cfa-epinow2-pipeline/ diff --git a/NAMESPACE b/NAMESPACE index b6920831..25d4e0a6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,3 +4,5 @@ export(add_two_numbers) export(download_from_azure_blob) export(fetch_blob_container) export(fetch_credential_from_env_var) +export(read_disease_parameters) +export(read_interval_pmf) diff --git a/NEWS.md b/NEWS.md index bfa2eb3b..efaf34db 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # CFAEpiNow2Pipeline (development version) +* Parameters read from local parquet file or files * Additional CI bug squashing * Bug fixed in the updated, faster pre-commit checks * Updated, faster pre-commit checks diff --git a/R/parameters.R b/R/parameters.R new file mode 100644 index 00000000..a6c5dbfa --- /dev/null +++ b/R/parameters.R @@ -0,0 +1,238 @@ +#' Read in disease process parameters from an external file or files +#' +#' @param generation_interval_path,delay_interval_path,right_truncation_path +#' Path to a local file with the parameter PMF. See [read_interval_pmf] for +#' details on the file schema. The parameters can be in the same file or a +#' different file. +#' @param disease One of `COVID-19` or `Influenza` +#' @param as_of_date Use the parameters that were used in production on this +#' date. Set for the current date for the most up-to-to date version of the +#' parameters and set to an earlier date to use parameters from an earlier +#' time period. +#' @param group Used only for parameters with a state-level estimate (i.e., only +#' right-truncation). The two-letter uppercase state abbreviation. +#' +#' @return A named list with three PMFs. The list elements are named +#' `generation_interval`, `delay_interval`, and `right_truncation`. If a path +#' to a local file is not provided (NA or NULL), the corresponding parameter +#' estimate will be NA in the returned list. +#' @details `generation_interval_path` is required because the generation +#' interval is a required parameter for $R_t$ estimation. +#' `delay_interval_path` and `right_truncation_path` are optional (but +#' strongly suggested). +#' @export +read_disease_parameters <- function( + generation_interval_path, + delay_interval_path, + right_truncation_path, + disease, + as_of_date, + group) { + generation_interval <- read_interval_pmf( + path = generation_interval_path, + disease = disease, + as_of_date = as_of_date, + parameter = "generation_interval" + ) + + if (path_is_specified(delay_interval_path)) { + delay_interval <- read_interval_pmf( + path = delay_interval_path, + disease = disease, + as_of_date = as_of_date, + parameter = "delay" + ) + } else { + cli::cli_alert_warning( + "No delay interval path specified. Using a delay of 0 days." + ) + delay_interval <- NA + } + + if (path_is_specified(right_truncation_path)) { + right_truncation <- read_interval_pmf( + path = right_truncation_path, + disease = disease, + as_of_date = as_of_date, + parameter = "right_truncation", + group = group + ) + } else { + cli::cli_alert_warning( + "No right truncation path specified. Not adjusting for right truncation." + ) + right_truncation <- NA + } + + parameters <- list( + generation_interval = generation_interval, + delay_interval = delay_interval, + right_truncation = right_truncation + ) + return(parameters) +} + +path_is_specified <- function(path) { + !rlang::is_null(path) && + !rlang::is_na(path) +} + +#' Read parameter PMF into memory +#' +#' Using DuckDB from a parquet file. The function expects the file to be in SCD2 +#' format with column names: +#' * parameter +#' * geo_value +#' * disease +#' * start_date +#' * end_date +#' * value +#' +#' start_date and end_date specify the date range for which the value was used. +#' end_date may be NULL (e.g. for the current value used in production). value +#' must contain a pmf vector whose values are all positive and sum to 1. all +#' other fields must be consistent with the specifications of the function +#' arguments describe below, which are used to query from the .parquet file. +#' +#' SCD2 format is shorthand for slowly changing dimension type 2. This format is +#' normalized to track change over time: +#' https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row +#' +#' @param path A path to a local file +#' @param disease One of "COVID-19" or "Influenza" +#' @param as_of_date The parameters "as of" the date of the model run +#' @param parameter One of "generation interval", "delay", or "right-truncation +#' @param group An optional parameter to subset the query to a parameter with a +#' particular two-letter state abbrevation. Right now, the only parameter with +#' state-specific estimates is `right-truncation`. +#' +#' @return A PMF vector +#' @export +read_interval_pmf <- function(path, + disease = c("COVID-19", "Influenza", "test"), + as_of_date, + parameter = c( + "generation_interval", + "delay", + "right_truncation" + ), + group = NA) { + ################### + # Validate input + rlang::arg_match(parameter) + rlang::arg_match(disease) + + as_of_date <- stringify_date(as_of_date) + cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}") + if (!file.exists(path)) { + cli::cli_abort("File {.path {path}} does not exist", + class = "file_not_found" + ) + } + + + ################ + # Prepare query + + query <- " + SELECT value + FROM read_parquet(?) + WHERE 1=1 + AND parameter = ? + AND disease = ? + AND start_date < ? :: DATE + AND (end_date > ? OR end_date IS NULL) + " + parameters <- list( + path, + parameter, + disease, + as_of_date, + as_of_date + ) + + # Handle state separately because can't use `=` for NULL comparison and + # DBI::dbBind() can't parameterize a query after IS + if (rlang::is_na(group) || rlang::is_null(group)) { + query <- paste(query, "AND geo_value IS NULL;") + } else { + query <- paste(query, "AND geo_value = ?") + parameters <- c(parameters, list(group)) + } + + ################ + # Execute query + + con <- DBI::dbConnect(duckdb::duckdb()) + pmf_df <- rlang::try_fetch( + DBI::dbGetQuery( + conn = con, + statement = query, + params = parameters + ), + error = function(cnd) { + cli::cli_abort( + c( + "Failure loading {.arg {parameter}} from {.path {path}}", + "Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}", + "Original error: {cnd}" + ), + class = "wrapped_error" + ) + } + ) + DBI::dbDisconnect(con) + + + ################ + # Validate loaded PMF + if (nrow(pmf_df) != 1) { + cli::cli_abort( + c( + "Failure loading {.arg {parameter}} from {.path {path}} ", + "Query did not return exactly one row", + "Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}", + "Query matched {.val {nrow(pmf_df)}} rows" + ), + class = "not_one_row_returned" + ) + } + + pmf <- pmf_df[["value"]][[1]] + + if ((length(pmf) < 1) || !rlang::is_bare_numeric(pmf)) { + cli::cli_abort( + c( + "Invalid {.arg {parameter}} returned.", + "x" = "Expected a PMF", + "i" = "Loaded object: {pmf_df}" + ), + class = "not_a_pmf" + ) + } + + if (any(pmf < 0) || any(pmf > 1) || abs(sum(pmf) - 1) > 1e-10) { + cli::cli_abort( + c( + "Returned numeric vector is not a valid PMF", + "Any below 0: {any(pmf < 0)}", + "Any above 1: {any(pmf > 1)}", + "Sum is within 1 with tol of 1e-10: {abs(sum(pmf) - 1) < 1e-10}", + "pmf: {.val {pmf}}" + ), + class = "invalid_pmf" + ) + } + + cli::cli_alert_success("{.arg {parameter}} loaded") + + return(pmf) +} + +#' DuckDB date comparison fails if the dates are not in string format +#' @noRd +stringify_date <- function(date) { + if (inherits(date, "Date")) { + format(date, "%Y-%m-%d") + } +} diff --git a/man/add_two_numbers.Rd b/man/add_two_numbers.Rd index ba7a1446..ffd18496 100644 --- a/man/add_two_numbers.Rd +++ b/man/add_two_numbers.Rd @@ -15,5 +15,6 @@ add_two_numbers(x, y) Their sum } \description{ -A temp function +Adding some meaningless text to test rendering via PR command. +Adding some more } diff --git a/man/read_disease_parameters.Rd b/man/read_disease_parameters.Rd new file mode 100644 index 00000000..7003932b --- /dev/null +++ b/man/read_disease_parameters.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parameters.R +\name{read_disease_parameters} +\alias{read_disease_parameters} +\title{Read in disease process parameters from an external file or files} +\usage{ +read_disease_parameters( + generation_interval_path, + delay_interval_path, + right_truncation_path, + disease, + as_of_date, + group +) +} +\arguments{ +\item{generation_interval_path, delay_interval_path, right_truncation_path}{Path to a local file with the parameter PMF. See \link{read_interval_pmf} for +details on the file schema. The parameters can be in the same file or a +different file.} + +\item{disease}{One of \code{COVID-19} or \code{Influenza}} + +\item{as_of_date}{Use the parameters that were used in production on this +date. Set for the current date for the most up-to-to date version of the +parameters and set to an earlier date to use parameters from an earlier +time period.} + +\item{group}{Used only for parameters with a state-level estimate (i.e., only +right-truncation). The two-letter uppercase state abbreviation.} +} +\value{ +A named list with three PMFs. The list elements are named +\code{generation_interval}, \code{delay_interval}, and \code{right_truncation}. If a path +to a local file is not provided (NA or NULL), the corresponding parameter +estimate will be NA in the returned list. +} +\description{ +Read in disease process parameters from an external file or files +} +\details{ +\code{generation_interval_path} is required because the generation +interval is a required parameter for $R_t$ estimation. +\code{delay_interval_path} and \code{right_truncation_path} are optional (but +strongly suggested). +} diff --git a/man/read_interval_pmf.Rd b/man/read_interval_pmf.Rd new file mode 100644 index 00000000..fa1897f6 --- /dev/null +++ b/man/read_interval_pmf.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parameters.R +\name{read_interval_pmf} +\alias{read_interval_pmf} +\title{Read parameter PMF into memory} +\usage{ +read_interval_pmf( + path, + disease = c("COVID-19", "Influenza", "test"), + as_of_date, + parameter = c("generation_interval", "delay", "right_truncation"), + group = NA +) +} +\arguments{ +\item{path}{A path to a local file} + +\item{disease}{One of "COVID-19" or "Influenza"} + +\item{as_of_date}{The parameters "as of" the date of the model run} + +\item{parameter}{One of "generation interval", "delay", or "right-truncation} + +\item{group}{An optional parameter to subset the query to a parameter with a +particular two-letter state abbrevation. Right now, the only parameter with +state-specific estimates is \code{right-truncation}.} +} +\value{ +A PMF vector +} +\description{ +Using DuckDB from a parquet file. The function expects the file to be in SCD2 +format with column names: +\itemize{ +\item parameter +\item geo_value +\item disease +\item start_date +\item end_date +\item value +} +} +\details{ +start_date and end_date specify the date range for which the value was used. +end_date may be NULL (e.g. for the current value used in production). value +must contain a pmf vector whose values are all positive and sum to 1. all +other fields must be consistent with the specifications of the function +arguments describe below, which are used to query from the .parquet file. + +SCD2 format is shorthand for slowly changing dimension type 2. This format is +normalized to track change over time: +https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row +} diff --git a/tests/testthat/helper-write_parameter_file.R b/tests/testthat/helper-write_parameter_file.R new file mode 100644 index 00000000..387ec9ca --- /dev/null +++ b/tests/testthat/helper-write_parameter_file.R @@ -0,0 +1,30 @@ +write_sample_parameters_file <- function(value, + path, + state, + param, + disease, + parameter, + start_date, + end_date) { + df <- data.frame( + start_date = as.Date(start_date), + geo_value = state, + disease = disease, + parameter = parameter, + end_date = end_date, + value = I(list(value)) + ) + + con <- DBI::dbConnect(duckdb::duckdb()) + + duckdb::duckdb_register(con, "test_table", df) + # This is bad practice but `dbBind()` doesn't allow us to parameterize COPY + # ... TO. The danger of doing it this way seems quite low risk because it's + # an ephemeral from a temporary in-memory DB. There's no actual database to + # guard against a SQL injection attack. + query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'") + DBI::dbExecute(con, query) + DBI::dbDisconnect(con) + + invisible(path) +} diff --git a/tests/testthat/test-parameters.R b/tests/testthat/test-parameters.R new file mode 100644 index 00000000..bf0ca0c6 --- /dev/null +++ b/tests/testthat/test-parameters.R @@ -0,0 +1,424 @@ +test_that("Can read all params on happy path", { + expected <- c(0.8, 0.2) + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + parameter = "generation_interval", + path = "generation_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "delay", + path = "delay_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "right_truncation", + path = "right_truncation.parquet", + disease = disease, + state = "test", + start_date = start_date, + end_date = NA + ) + + + actual <- read_disease_parameters( + generation_interval_path = "generation_interval.parquet", + delay_interval_path = "delay_interval.parquet", + right_truncation_path = "right_truncation.parquet", + disease = "COVID-19", + as_of_date = start_date + 1, + group = "test" + ) + }) + + + expect_equal( + actual, + list( + generation_interval = expected, + delay_interval = expected, + right_truncation = expected + ) + ) +}) + +test_that("Can skip params on happy path", { + expected <- c(0.8, 0.2) + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + parameter = "generation_interval", + path = "generation_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "delay", + path = "delay_interval.parquet", + disease = disease, + state = NA, + start_date = start_date, + end_date = NA + ) + write_sample_parameters_file( + value = expected, + parameter = "right_truncation", + path = "right_truncation.parquet", + disease = disease, + state = "test", + start_date = start_date, + end_date = NA + ) + + + actual <- read_disease_parameters( + generation_interval_path = "generation_interval.parquet", + delay_interval_path = NULL, + right_truncation_path = NULL, + disease = "COVID-19", + as_of_date = start_date + 1, + group = "test" + ) + }) + + + expect_equal( + actual, + list( + generation_interval = expected, + delay_interval = NA, + right_truncation = NA + ) + ) +}) + +test_that("Can read right-truncation on happy path", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "right_truncation" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ) + }) + expect_equal(actual, expected) + + + # Influenza + disease <- "Influenza" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ) + }) + expect_equal(actual, expected) +}) + +test_that("Invalid PMF errors", { + expected <- c(0.8, -0.1) + path <- "test.parquet" + parameter <- "right_truncation" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + expect_error( + read_interval_pmf( + path = path, + parameter = parameter, + disease = disease, + as_of_date = start_date + 1, + group = "test" + ), + class = "invalid_pmf" + ) + }) +}) + + +test_that("Can read delay on happy path", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ) + }) + expect_equal(actual, expected) + + + # Influenza + disease <- "Influenza" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + actual <- read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ) + }) + expect_equal(actual, expected) +}) + + +test_that("Not a PMF errors", { + expected <- "hello" + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + + # COVID-19 + disease <- "COVID-19" + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + class = "not_a_pmf" + ) + }) +}) + +test_that("Invalid disease errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "not_a_valid_disease" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + regexp = "`disease` must be one of" + ) + }) +}) + +test_that("Invalid parameter errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "not_a_valid_parameter" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date + 1, + parameter = parameter + ), + regexp = "`parameter` must be one of" + ) + }) +}) + +test_that("Return isn't exactly one errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = "test", + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + # Date too early + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = start_date - 1, + parameter = parameter + ), + class = "not_one_row_returned" + ) + }) +}) + +test_that("No file exists errors", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + expect_error( + read_interval_pmf( + path = "not_a_real_file", + disease = disease, + as_of_date = start_date - 1, + parameter = parameter + ), + class = "file_not_found" + ) +}) + +test_that("Invalid query throws wrapped error", { + expected <- c(0.8, 0.2) + path <- "test.parquet" + parameter <- "delay" + start_date <- as.Date("2023-01-01") + disease <- "COVID-19" + + withr::with_tempdir({ + write_sample_parameters_file( + value = expected, + path = path, + state = NA, + disease = disease, + parameter = parameter, + param = parameter, + start_date = start_date, + end_date = NA + ) + + expect_error( + read_interval_pmf( + path = path, + disease = disease, + as_of_date = "abc123", + parameter = parameter + ), + class = "wrapped_error" + ) + }) +})