Skip to content

Commit

Permalink
Wrapper to read all params
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Aug 31, 2024
1 parent fa6acaa commit bd7b211
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 15 deletions.
84 changes: 71 additions & 13 deletions R/parameters.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,57 @@
read_parameters <- function(
generation_interval_path,
delay_interval_path,
right_truncation_path,
disease,
as_of_date,
state) {
generation_interval <- read_interval_pmf(
path = generation_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "generation_interval"
)

if (path_is_specified(delay_interval_path)) {
delay_interval <- read_interval_pmf(
path = delay_interval_path,
disease = disease,
as_of_date = as_of_date,
parameter = "delay"
)
} else {
cli::cli_alert_warning(
"No delay interval path specified. Using a delay of 0 days"
)
}

if (path_is_specified(right_truncation_path)) {
right_truncation <- read_interval_pmf(
path = right_truncation_path,
disease = disease,
as_of_date = as_of_date,
parameter = "right_truncation",
state = state
)
} else {
cli::cli_alert_warning(
"No right truncation path specified. Not adjusting for right truncation."
)
}

parameters <- list(
generation_interval = generation_interval,
delay_interval = delay_interval,
right_truncation = right_truncation
)
return(parameters)
}

path_is_specified <- function(path) {
!rlang::is_null(path) &&
!rlang::is_na(path)
}

#' Read parameter PMFs into memory
#'
#' Using DuckDB from a parquet file. The function expects the file to be in SCD2
Expand All @@ -13,6 +67,9 @@
#' @param disease One of "COVID-19" or "Influenza"
#' @param as_of_date The parameters "as of" the date of the model run
#' @param parameter One of "generation interval", "delay", or "right-truncation
#' @param state An optional parameter to subset the query to a parameter with a
#' particular two-letter state abbrevation. Right now, the only parameter with
#' state-specific estimates is `right-truncation`.
#'
#' @return A PMF vector
#' @export
Expand All @@ -30,16 +87,12 @@ read_interval_pmf <- function(path,
rlang::arg_match(parameter)
rlang::arg_match(disease)

as_of_date <- stringify_date(as_of_date)
cli::cli_alert_info("Reading {.arg right_truncation} from {.path {path}}")
if (!file.exists(path)) {
cli::cli_abort("File {.path {path}} does not exist")
}

# The comparison fails if the dates are not in string format, reformat if
# needed
if (inherits(as_of_date, "Date")) {
as_of_date <- format(as_of_date, "%Y-%m-%d")
}

################
# Prepare query
Expand Down Expand Up @@ -93,16 +146,11 @@ read_interval_pmf <- function(path,

################
# Validate loaded PMF
pmf <- check_returned_pmf(pmf_df)
cli::cli_alert_success("{.arg {parameter}} loaded")

return(pmf)
}

check_returned_pmf <- function(pmf_df) {
if (nrow(pmf_df) != 1) {
cli::cli_abort(c(
"Failure loading {.arg {parameter}} from {.path {path}} ",
"Query did not return exactly one row",
"Using {.val {disease}}, {.val {as_of_date}}, and {.val {state}}",
"Query matched {.val {nrow(pmf_df)}} rows"
))
}
Expand All @@ -122,9 +170,19 @@ check_returned_pmf <- function(pmf_df) {
"Returned numeric vector is not a valid PMF",
"Any below 0: {any(pmf < 0)}",
"Any above 1: {any(pmf > 1)}",
"Within 1 with tol of 1e-10: {abs(sum(pmf) - 1) < 1e-10}"
"Within 1 with tol of 1e-10: {abs(sum(pmf) - 1) < 1e-10},
pmf: : {.val {pmf}}"
))
}

cli::cli_alert_success("{.arg {parameter}} loaded")

return(pmf)
}

#' DuckDB date comparison fails if the dates are not in string format
stringify_date <- function(date) {
if (inherits(as_of_date, "Date")) {
format(as_of_date, "%Y-%m-%d")
}
}
1 change: 0 additions & 1 deletion tests/testthat/helper-write_parameter_file.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ write_sample_parameters_file <- function(value,
disease = disease,
parameter = parameter,
end_date = end_date,
disease = disease,
value = I(list(value))
)

Expand Down
50 changes: 49 additions & 1 deletion tests/testthat/test-parameters.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,52 @@
test_that("Can read all params on happy path", {
expected <- c(0.8, 0.2)
start_date <- as.Date("2023-01-01")
disease <- "COVID-19"

withr::with_tempdir({
write_sample_parameters_file(
value = expected,
parameter = "generation_interval",
path = "generation_interval.parquet",
disease = disease,
state = NA,
start_date = start_date,
end_date = NA
)
write_sample_parameters_file(
value = expected,
parameter = "delay",
path = "delay_interval.parquet",
disease = disease,
state = NA,
start_date = start_date,
end_date = NA
)
write_sample_parameters_file(
value = expected,
parameter = "right_truncation",
path = "right_truncation.parquet",
disease = disease,
state = "test",
start_date = start_date,
end_date = NA
)


actual <- read_parameters(
generation_interval_path = "generation_interval.parquet",
delay_interval_path = "delay_interval.parquet",
right_truncation_path = "right_truncation.parquet",
disease = "COVID-19",
as_of_date = start_date + 1,
state = "test"
)
})


expect_equal(actual, c(rep(list(c(0.8, 0.2)), 3)))
})

test_that("Can read right-truncation on happy path", {
expected <- c(0.8, 0.2)
path <- "test.parquet"
Expand All @@ -13,7 +62,6 @@ test_that("Can read right-truncation on happy path", {
state = "test",
disease = disease,
parameter = parameter,
param = parameter,
start_date = start_date,
end_date = NA
)
Expand Down

0 comments on commit bd7b211

Please sign in to comment.