Read in parameter PMFs (#16)
* Parameter PMF reader

For generation interval, delays, and right-truncation.

It assumes that:
- We want a PMF
- The PMF is coming from a file with schema specified as in
  cdcent/cfa-parameter-estimates#9
- The parameter names and disease names are following a specified schema
- The file is in parquet format
- The PMFs are actually proper PMFs
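
As a rough illustration of what "proper PMF" means here, the check below mirrors the conditions the reader enforces (numeric, non-empty, every value in [0, 1], sum within 1e-10 of 1). The helper name `is_proper_pmf` and the `!anyNA()` guard are assumptions made for this sketch; neither is part of the package.

```r
# Hypothetical helper for illustration only; not exported by the package
is_proper_pmf <- function(pmf, tol = 1e-10) {
  is.numeric(pmf) &&
    length(pmf) >= 1 &&
    !anyNA(pmf) &&              # extra guard assumed for this sketch
    all(pmf >= 0) &&            # no negative probabilities
    all(pmf <= 1) &&            # no single mass above 1
    abs(sum(pmf) - 1) <= tol    # total mass is 1, up to floating-point error
}

valid <- is_proper_pmf(c(0.2, 0.5, 0.3))
too_much_mass <- is_proper_pmf(c(0.6, 0.6))
out_of_range <- is_proper_pmf(c(-0.1, 1.1))
```

Note that the last vector sums to 1 but still fails, which is why the range checks are kept separate from the sum check.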

It relaxes the assumptions that:
- The PMFs are coming from the same file
- The PMFs must be present (can skip delays and right-truncation but
  not GI)
- The files are in Azure or have a specific name/path

Unit tests cover successfully reading all the parameters individually
as well as in combination. They also check for failure in the expected
places for desired failure modes.
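
A minimal sketch of the "as of" lookup against a drop-in file, assuming the `DBI` and `duckdb` packages are installed. The toy file, its two rows, and the helper `read_as_of()` are made up for illustration; the real schema is the one referenced above, and the real reader also parameterizes `parameter`, `disease`, and the group.

```r
library(DBI)

# Hypothetical demo file written to a temp path
path <- tempfile(fileext = ".parquet")
con <- dbConnect(duckdb::duckdb())

# Two SCD2 rows for one parameter: a retired estimate and the current one
dbExecute(con, sprintf("
  COPY (
    SELECT * FROM (VALUES
      ('generation_interval', NULL::VARCHAR, 'COVID-19',
       DATE '2023-01-01', DATE '2024-01-01', [0.5, 0.5]::DOUBLE[]),
      ('generation_interval', NULL::VARCHAR, 'COVID-19',
       DATE '2024-01-01', NULL::DATE, [0.2, 0.5, 0.3]::DOUBLE[])
    ) AS t(parameter, geo_value, disease, start_date, end_date, value)
  ) TO '%s' (FORMAT PARQUET)", path))

# Hypothetical helper: fetch the row whose validity window contains the date
read_as_of <- function(con, path, as_of_date) {
  dbGetQuery(con, "
    SELECT value
    FROM read_parquet(?)
    WHERE parameter = 'generation_interval'
      AND disease = 'COVID-19'
      AND start_date < ?::DATE
      AND (end_date > ?::DATE OR end_date IS NULL)
      AND geo_value IS NULL",
    params = list(path, as_of_date, as_of_date)
  )
}

old <- read_as_of(con, path, "2023-06-01")      # matches the retired row
current <- read_as_of(con, path, "2024-06-01")  # matches the open-ended row
dbDisconnect(con, shutdown = TRUE)
```

In this toy file, an `as_of_date` exactly equal to the shared boundary (2024-01-01) would match neither row because both comparisons are strict, so the "exactly one row" check would fail for that date.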

Switching from manual CSVs to this schema now is a deliberate choice.
I took three cracks at this PR before landing on this approach. I
think making the switch now provides a few important benefits, which
come through in the design of the code here:
1. Our existing CSV-based approach produces 3 distinct files with
   close-ish schemas, but not an exact match. It requires distinct
   reader functions and substantially more code than reading from a
   file with a unified schema like I do here.
2. I think this approach helps avoid issues like this week's production
   mishap. It allows us to mix and match files as needed, and we can
   point to a single drop-in file for testing if desired. We don't need
   to fiddle with the production environment.
3. I think we're going to need to make a switch on the parameter
   approach sooner or later. I wanted to bite the bullet and get it
   over with all at once. It's convenient that it helps make the code
   simpler, but I think making changes swiftly rather than dragging
   things out provides its own benefit too.

* Bump NEWS

* Document new functions

* Update NEWS.md

* Drop roxygenize hooking -- it's fully broken

* Clarify `as_of` date

* Clarify state-varying params

* Use classed errors

For additional specificity linking expected errors to tests. I use the
regexp option instead of classed errors for `arg_match()` because I had
trouble catching those even when matching the class exactly.
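
To make the two matching styles concrete, here is a small sketch assuming `cli`, `rlang`, and `testthat` are available. `read_stub()` and `pick()` are hypothetical stand-ins, not the package's functions; only the `file_not_found` class name is taken from this change.

```r
library(testthat)

# Hypothetical stub that only performs the file-existence check
read_stub <- function(path) {
  if (!file.exists(path)) {
    cli::cli_abort(
      "File {.path {path}} does not exist",
      class = "file_not_found"
    )
  }
  invisible(path)
}

# Classed error: the test is tied to the failure mode, not the message wording
test_that("missing file raises a classed error", {
  expect_error(read_stub("not_a_real_file.parquet"), class = "file_not_found")
})

# Hypothetical wrapper around arg_match(), matched by message regexp instead
pick <- function(parameter = c("generation_interval", "delay")) {
  rlang::arg_match(parameter)
}

test_that("bad argument fails with a matching message", {
  expect_error(pick("not_a_parameter"), regexp = "must be one of")
})
```

The classed form keeps passing if the message text is reworded later, while the regexp form depends on the wording `arg_match()` produces.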

* Document

* Apply suggestions from code review

Co-authored-by: Katie Gostic (she/her) <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add clarifying info from code review

* state -> group

But I left state in the documentation for clarity because we don't
currently do anything at the sub-state level

* Document

* Update tests & docs with changes from review

* `read_parameters()` -> `read_disease_parameters()`

---------

Co-authored-by: zsusswein <[email protected]>
Co-authored-by: Katie Gostic (she/her) <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Sep 4, 2024
1 parent c239172 commit 9152bc0
Showing 10 changed files with 801 additions and 7 deletions.
10 changes: 4 additions & 6 deletions .pre-commit-config.yaml
@@ -8,7 +8,6 @@ repos:
- id: style-files
args: [--style_pkg=styler, --style_fun=tidyverse_style,
--cache-root=styler-perm]
-- id: roxygenize
- id: use-tidy-description
- id: lintr
- id: readme-rmd-rendered
@@ -49,7 +48,7 @@ repos:
#####
# Python
- repo: https://github.com/psf/black
-rev: 24.4.2
+rev: 24.8.0
hooks:
# if you have ipython notebooks, consider using
# `black-jupyter` hook instead
@@ -62,13 +61,13 @@ repos:
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.4.2
+rev: v0.6.3
hooks:
- id: ruff
#####
# Java
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
-rev: v2.13.0
+rev: v2.14.0
hooks:
- id: pretty-format-java
args: [--aosp,--autofix]
@@ -83,7 +82,7 @@ repos:
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
-rev: v1.4.0
+rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
@@ -97,5 +96,4 @@ ci:
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
-skip: [roxygenize]
submodules: false
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -26,5 +26,7 @@ Imports:
AzureRMR,
AzureStor,
cli,
+DBI,
+duckdb,
rlang
URL: https://cdcgov.github.io/cfa-epinow2-pipeline/
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -4,3 +4,5 @@ export(add_two_numbers)
export(download_from_azure_blob)
export(fetch_blob_container)
export(fetch_credential_from_env_var)
+export(read_disease_parameters)
+export(read_interval_pmf)
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
# CFAEpiNow2Pipeline (development version)

+* Parameters read from local parquet file or files
* Additional CI bug squashing
* Bug fixed in the updated, faster pre-commit checks
* Updated, faster pre-commit checks
238 changes: 238 additions & 0 deletions R/parameters.R
@@ -0,0 +1,238 @@
#' Read in disease process parameters from an external file or files
#'
#' @param generation_interval_path,delay_interval_path,right_truncation_path
#' Path to a local file with the parameter PMF. See [read_interval_pmf] for
#' details on the file schema. The parameters can be in the same file or a
#' different file.
#' @param disease One of `COVID-19` or `Influenza`
#' @param as_of_date Use the parameters that were used in production on this
#' date. Set to the current date for the most up-to-date version of the
#' parameters and set to an earlier date to use parameters from an earlier
#' time period.
#' @param group Used only for parameters with a state-level estimate (i.e., only
#' right-truncation). The two-letter uppercase state abbreviation.
#'
#' @return A named list with three PMFs. The list elements are named
#' `generation_interval`, `delay_interval`, and `right_truncation`. If a path
#' to a local file is not provided (NA or NULL), the corresponding parameter
#' estimate will be NA in the returned list.
#' @details `generation_interval_path` is required because the generation
#' interval is a required parameter for $R_t$ estimation.
#' `delay_interval_path` and `right_truncation_path` are optional (but
#' strongly suggested).
#' @export
read_disease_parameters <- function(
generation_interval_path,
delay_interval_path,
right_truncation_path,
disease,
as_of_date,
group) {
generation_interval <- read_interval_pmf(
path = generation_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "generation_interval"
)

if (path_is_specified(delay_interval_path)) {
delay_interval <- read_interval_pmf(
path = delay_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "delay"
)
} else {
cli::cli_alert_warning(
"No delay interval path specified. Using a delay of 0 days."
)
delay_interval <- NA
}

if (path_is_specified(right_truncation_path)) {
right_truncation <- read_interval_pmf(
path = right_truncation_path,
disease = disease,
as_of_date = as_of_date,
parameter = "right_truncation",
group = group
)
} else {
cli::cli_alert_warning(
"No right truncation path specified. Not adjusting for right truncation."
)
right_truncation <- NA
}

parameters <- list(
generation_interval = generation_interval,
delay_interval = delay_interval,
right_truncation = right_truncation
)
return(parameters)
}

path_is_specified <- function(path) {
!rlang::is_null(path) &&
!rlang::is_na(path)
}

#' Read parameter PMF into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
#' format with column names:
#' * parameter
#' * geo_value
#' * disease
#' * start_date
#' * end_date
#' * value
#'
#' start_date and end_date specify the date range for which the value was used.
#' end_date may be NULL (e.g. for the current value used in production). value
#' must contain a PMF vector whose values are all non-negative and sum to 1. All
#' other fields must be consistent with the specifications of the function
#' arguments described below, which are used to query from the .parquet file.
#'
#' SCD2 format is shorthand for slowly changing dimension type 2. This format is
#' normalized to track change over time:
#' https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row
#'
#' @param path A path to a local file
#' @param disease One of "COVID-19" or "Influenza"
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation_interval", "delay", or "right_truncation"
#' @param group An optional parameter to subset the query to a parameter with a
#' particular two-letter state abbreviation. Right now, the only parameter with
#' state-specific estimates is `right_truncation`.
#'
#' @return A PMF vector
#' @export
read_interval_pmf <- function(path,
disease = c("COVID-19", "Influenza", "test"),
as_of_date,
parameter = c(
"generation_interval",
"delay",
"right_truncation"
),
group = NA) {
###################
# Validate input
# arg_match() returns the matched value, so keep it rather than discard it
parameter <- rlang::arg_match(parameter)
disease <- rlang::arg_match(disease)

as_of_date <- stringify_date(as_of_date)
cli::cli_alert_info("Reading {.arg {parameter}} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist",
class = "file_not_found"
)
}


################
# Prepare query

query <- "
SELECT value
FROM read_parquet(?)
WHERE 1=1
AND parameter = ?
AND disease = ?
AND start_date < ? :: DATE
AND (end_date > ? OR end_date IS NULL)
"
parameters <- list(
path,
parameter,
disease,
as_of_date,
as_of_date
)

# Handle state separately because can't use `=` for NULL comparison and
# DBI::dbBind() can't parameterize a query after IS
if (rlang::is_na(group) || rlang::is_null(group)) {
query <- paste(query, "AND geo_value IS NULL;")
} else {
query <- paste(query, "AND geo_value = ?")
parameters <- c(parameters, list(group))
}

################
# Execute query

con <- DBI::dbConnect(duckdb::duckdb())
# Close the connection even if the query below errors
on.exit(DBI::dbDisconnect(con), add = TRUE)
pmf_df <- rlang::try_fetch(
DBI::dbGetQuery(
conn = con,
statement = query,
params = parameters
),
error = function(cnd) {
cli::cli_abort(
c(
"Failure loading {.arg {parameter}} from {.path {path}}",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}",
"Original error: {cnd}"
),
class = "wrapped_error"
)
}
)


################
# Validate loaded PMF
if (nrow(pmf_df) != 1) {
cli::cli_abort(
c(
"Failure loading {.arg {parameter}} from {.path {path}} ",
"Query did not return exactly one row",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {group}}",
"Query matched {.val {nrow(pmf_df)}} rows"
),
class = "not_one_row_returned"
)
}

pmf <- pmf_df[["value"]][[1]]

if ((length(pmf) < 1) || !rlang::is_bare_numeric(pmf)) {
cli::cli_abort(
c(
"Invalid {.arg {parameter}} returned.",
"x" = "Expected a PMF",
"i" = "Loaded object: {pmf_df}"
),
class = "not_a_pmf"
)
}

if (any(pmf < 0) || any(pmf > 1) || abs(sum(pmf) - 1) > 1e-10) {
cli::cli_abort(
c(
"Returned numeric vector is not a valid PMF",
"Any below 0: {any(pmf < 0)}",
"Any above 1: {any(pmf > 1)}",
"Sum is within 1e-10 of 1: {abs(sum(pmf) - 1) <= 1e-10}",
"pmf: {.val {pmf}}"
),
class = "invalid_pmf"
)
}

cli::cli_alert_success("{.arg {parameter}} loaded")

return(pmf)
}

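#' DuckDB date comparison fails if the dates are not in string format
#' @noRd
stringify_date <- function(date) {
# Convert Date objects to ISO 8601 strings; pass other input through unchanged
# rather than silently returning NULL
if (inherits(date, "Date")) {
date <- format(date, "%Y-%m-%d")
}
date
}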
3 changes: 2 additions & 1 deletion man/add_two_numbers.Rd
45 changes: 45 additions & 0 deletions man/read_disease_parameters.Rd