Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read in parameter PMFs #16

Merged
merged 16 commits into from
Sep 4, 2024
Merged
10 changes: 4 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ repos:
- id: style-files
args: [--style_pkg=styler, --style_fun=tidyverse_style,
--cache-root=styler-perm]
- id: roxygenize
- id: use-tidy-description
- id: lintr
- id: readme-rmd-rendered
Expand Down Expand Up @@ -49,7 +48,7 @@ repos:
#####
# Python
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
# if you have ipython notebooks, consider using
# `black-jupyter` hook instead
Expand All @@ -62,13 +61,13 @@ repos:
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.6.3
hooks:
- id: ruff
#####
# Java
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.13.0
rev: v2.14.0
hooks:
- id: pretty-format-java
args: [--aosp,--autofix]
Expand All @@ -83,7 +82,7 @@ repos:
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
rev: v1.4.0
rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
Expand All @@ -97,5 +96,4 @@ ci:
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: [roxygenize]
submodules: false
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ Imports:
AzureRMR,
AzureStor,
cli,
DBI,
duckdb,
rlang
URL: https://cdcgov.github.io/cfa-epinow2-pipeline/
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ export(add_two_numbers)
export(download_from_azure_blob)
export(fetch_blob_container)
export(fetch_credential_from_env_var)
export(read_disease_parameters)
export(read_interval_pmf)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# CFAEpiNow2Pipeline (development version)

* Parameters read from local parquet file or files
* Additional CI bug squashing
* Bug fixed in the updated, faster pre-commit checks
* Updated, faster pre-commit checks
Expand Down
238 changes: 238 additions & 0 deletions R/parameters.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#' Read in disease process parameters from an external file or files
#'
#' @param generation_interval_path,delay_interval_path,right_truncation_path
#' Path to a local file with the parameter PMF. See [read_interval_pmf] for
#' details on the file schema. The parameters can be in the same file or a
#' different file.
#' @param disease One of `COVID-19` or `Influenza`
#' @param as_of_date Use the parameters that were used in production on this
#' date. Set for the current date for the most up-to-to date version of the
#' parameters and set to an earlier date to use parameters from an earlier
#' time period.
#' @param group Used only for parameters with a state-level estimate (i.e., only
#' right-truncation). The two-letter uppercase state abbreviation.
#'
#' @return A named list with three PMFs. The list elements are named
#' `generation_interval`, `delay_interval`, and `right_truncation`. If a path
#' to a local file is not provided (NA or NULL), the corresponding parameter
#' estimate will be NA in the returned list.
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#' @details `generation_interval_path` is required because the generation
#' interval is a required parameter for $R_t$ estimation.
#' `delay_interval_path` and `right_truncation_path` are optional (but
#' strongly suggested).
#' @export
read_disease_parameters <- function(
generation_interval_path,
delay_interval_path,
right_truncation_path,
disease,
as_of_date,
group) {
generation_interval <- read_interval_pmf(
path = generation_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "generation_interval"
)

if (path_is_specified(delay_interval_path)) {
delay_interval <- read_interval_pmf(
path = delay_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "delay"
)
} else {
cli::cli_alert_warning(
"No delay interval path specified. Using a delay of 0 days."
)
delay_interval <- NA
}

if (path_is_specified(right_truncation_path)) {
right_truncation <- read_interval_pmf(
path = right_truncation_path,
disease = disease,
as_of_date = as_of_date,
parameter = "right_truncation",
group = group
)
} else {
cli::cli_alert_warning(
"No right truncation path specified. Not adjusting for right truncation."
)
right_truncation <- NA
}

parameters <- list(
generation_interval = generation_interval,
delay_interval = delay_interval,
right_truncation = right_truncation
)
return(parameters)
}

path_is_specified <- function(path) {
!rlang::is_null(path) &&
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
!rlang::is_na(path)
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
}

#' Read parameter PMF into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#' format with column names:
#' * parameter
#' * geo_value
#' * disease
#' * start_date
#' * end_date
#' * value
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#'
#' start_date and end_date specify the date range for which the value was used.
#' end_date may be NULL (e.g. for the current value used in production). value
#' must contain a pmf vector whose values are all positive and sum to 1. all
#' other fields must be consistent with the specifications of the function
#' arguments describe below, which are used to query from the .parquet file.
#'
#' SCD2 format is shorthand for slowly changing dimension type 2. This format is
#' normalized to track change over time:
#' https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row
#'
#' @param path A path to a local file
#' @param disease One of "COVID-19" or "Influenza"
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation interval", "delay", or "right-truncation
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
#' @param group An optional parameter to subset the query to a parameter with a
#' particular two-letter state abbrevation. Right now, the only parameter with
#' state-specific estimates is `right-truncation`.
#'
#' @return A PMF vector
#' @export
read_interval_pmf <- function(path,
disease = c("COVID-19", "Influenza", "test"),
as_of_date,
parameter = c(
"generation_interval",
"delay",
"right_truncation"
),
group = NA) {
###################
# Validate input
rlang::arg_match(parameter)
rlang::arg_match(disease)

as_of_date <- stringify_date(as_of_date)
cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}")
if (!file.exists(path)) {
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
cli::cli_abort("File {.path {path}} does not exist",
class = "file_not_found"
)
}


################
# Prepare query

query <- "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
"
parameters <- list(
path,
parameter,
disease,
as_of_date,
as_of_date
)

# Handle state separately because can't use `=` for NULL comparison and
# DBI::dbBind() can't parameterize a query after IS
zsusswein marked this conversation as resolved.
Show resolved Hide resolved
if (rlang::is_na(group) || rlang::is_null(group)) {
query <- paste(query, "AND geo_value IS NULL;")
} else {
query <- paste(query, "AND geo_value = ?")
parameters <- c(parameters, list(group))
}

################
# Execute query

con <- DBI::dbConnect(duckdb::duckdb())
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = query,
params = parameters
),
error = function(cnd) {
cli::cli_abort(
c(
"Failure loading {.arg {parameter}} from {.path {path}}",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}",
"Original error: {cnd}"
),
class = "wrapped_error"
)
}
)
DBI::dbDisconnect(con)


################
# Validate loaded PMF
if (nrow(pmf_df) != 1) {
cli::cli_abort(
c(
"Failure loading {.arg {parameter}} from {.path {path}} ",
"Query did not return exactly one row",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}",
"Query matched {.val {nrow(pmf_df)}} rows"
),
class = "not_one_row_returned"
)
}

pmf <- pmf_df[["value"]][[1]]

if ((length(pmf) < 1) || !rlang::is_bare_numeric(pmf)) {
cli::cli_abort(
c(
"Invalid {.arg {parameter}} returned.",
"x" = "Expected a PMF",
"i" = "Loaded object: {pmf_df}"
),
class = "not_a_pmf"
)
}

if (any(pmf < 0) || any(pmf > 1) || abs(sum(pmf) - 1) > 1e-10) {
cli::cli_abort(
c(
"Returned numeric vector is not a valid PMF",
"Any below 0: {any(pmf < 0)}",
"Any above 1: {any(pmf > 1)}",
"Sum is within 1 with tol of 1e-10: {abs(sum(pmf) - 1) < 1e-10}",
"pmf: {.val {pmf}}"
),
class = "invalid_pmf"
)
}

cli::cli_alert_success("{.arg {parameter}} loaded")

return(pmf)
}

#' DuckDB date comparison fails if the dates are not in string format
#' @noRd
stringify_date <- function(date) {
if (inherits(date, "Date")) {
format(date, "%Y-%m-%d")
}
}
3 changes: 2 additions & 1 deletion man/add_two_numbers.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions man/read_disease_parameters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading