diff --git a/.Rbuildignore b/.Rbuildignore index 8ae3faf3..af67f6c0 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,6 +1,7 @@ ^CONTRIBUTING.md$ ^DISCLAIMER.md$ ^Dockerfile$ +^Dockerfile-dependencies$ ^LICENSE.md$ ^Makefile$ ^\.github$ @@ -10,6 +11,7 @@ ^\.vscode$ ^\.vscode$ ^_pkgdown\.yml$ +^batch-autoscale-formula.txt$ ^code-of-conduct.md$ ^codecov\.yml$ ^data-raw$ diff --git a/.gitattributes b/.gitattributes index ab87cc58..e136624c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,5 @@ +NEWS.md merge=union + # Normal text let sit to auto *.htm text *.html text diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 226597b8..7c0638ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: entry: Cannot commit .Rhistory, .RData, .Rds or .rds. language: fail files: '\.(Rhistory|RData|Rds|rds)$' + exclude: '^tests/testthat/data/.*\.rds$' # `exclude: ` to allow committing specific files # Secrets - repo: https://github.com/Yelp/detect-secrets diff --git a/DESCRIPTION b/DESCRIPTION index 262970ec..ae971bd6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,11 +27,18 @@ Imports: AzureRMR, AzureStor, cli, + data.table, DBI, + dplyr, duckdb, EpiNow2 (>= 1.4.0), - rlang + jsonlite, + rlang, + rstan, + tidybayes +Additional_repositories: + https://stan-dev.r-universe.dev URL: https://cdcgov.github.io/cfa-epinow2-pipeline/ Depends: - R (>= 2.10) + R (>= 3.50) LazyData: true diff --git a/NAMESPACE b/NAMESPACE index 7016e7c9..542cc5f3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,13 +2,19 @@ export(apply_exclusions) export(download_from_azure_blob) +export(extract_diagnostics) export(fetch_blob_container) export(fetch_credential_from_env_var) export(fit_model) export(format_delay_interval) export(format_generation_interval) export(format_right_truncation) +export(low_case_count_diagnostic) +export(process_quantiles) +export(process_samples) export(read_data) export(read_disease_parameters) export(read_exclusions) export(read_interval_pmf) +export(write_model_outputs) +export(write_output_dir_structure) diff --git a/NEWS.md b/NEWS.md index 1ae6fd90..5cf2223c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,3 +20,4 @@ * Removed `add.R` placeholder * Fix bugs in date casting caused by DuckDB v1.1.1 release * Drop unused pre-commit hooks +* Write outputs to file diff --git a/R/extract_diagnostics.R b/R/extract_diagnostics.R new file mode 100644 index 00000000..f0bd4595 --- /dev/null +++ b/R/extract_diagnostics.R @@ -0,0 +1,195 @@ +#' Extract diagnostic metrics from model fit and data +#' +#' This function extracts various diagnostic metrics from a fitted `EpiNow2` +#' model and provided data. It checks for low case counts and computes +#' diagnostics from the fitted model, including the mean acceptance +#' statistic, divergent transitions, maximum tree depth, and Rhat values. +#' These diagnostics are then flagged if they exceed specific thresholds, +#' and the results are returned as a data frame. +#' +#' @param fit The model fit object from `EpiNow2` +#' @param data A data frame containing the input data used in the model fit. +#' @param job_id A unique identifier for the job +#' @param task_id A unique identifier for the task +#' @param disease,geo_value,model Metadata for downstream processing. +#' +#' @return A \code{data.frame} containing the extracted diagnostic metrics. The +#' data frame includes the following columns: +#' \itemize{ +#' \item \code{diagnostic}: The name of the diagnostic metric. +#' \item \code{value}: The value of the diagnostic metric. +#' \item \code{job_id}: The unique identifier for the job. +#' \item \code{task_id}: The unique identifier for the task. +#' \item \code{disease,geo_value,model}: Metadata for downstream processing. +#' } +#' +#' @details +#' The following diagnostics are calculated: +#' \itemize{ +#' \item \code{mean_accept_stat}: The average acceptance statistic across +#' all chains. +#' \item \code{p_divergent}: The *proportion* of divergent transitions across +#' all samples. +#' \item \code{n_divergent}: The *number* of divergent transitions across +#' all samples. +#' \item \code{p_max_treedepth}: The proportion of samples that hit the +#' maximum tree depth. +#' \item \code{p_high_rhat}: The *proportion* of parameters with Rhat values +#' greater than 1.05, indicating potential convergence issues. +#' \item \code{n_high_rhat}: The *number* of parameters with Rhat values +#' greater than 1.05, indicating potential convergence issues. +#' \item \code{low_case_count_flag}: A flag indicating if there are low case +#' counts in the data. See \code{low_case_count_diagnostic()} for more +#' information on this diagnostic. +#' \item \code{epinow2_diagnostic_flag}: A combined flag that indicates if +#' any diagnostic thresholds are exceeded. The diagnostic thresholds +#' (1) mean_accept_stat < 0.1, (2) p_divergent > 0.0075, (3) +#' p_max_treedepth > 0.05, and (4) p_high_rhat > 0.0075. +#' } +#' @export +extract_diagnostics <- function(fit, + data, + job_id, + task_id, + disease, + geo_value, + model) { + low_case_count <- low_case_count_diagnostic(data) + + epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit, + inc_warmup = FALSE + ) + mean_accept_stat <- mean( + sapply(epinow2_diagnostics, function(x) mean(x[, "accept_stat__"])) + ) + p_divergent <- mean( + rstan::get_divergent_iterations(fit$estimates$fit), + na.rm = TRUE + ) + n_divergent <- sum( + rstan::get_divergent_iterations(fit$estimates$fit), + na.rm = TRUE + ) + p_max_treedepth <- mean( + rstan::get_max_treedepth_iterations(fit$estimates$fit), + na.rm = TRUE + ) + p_high_rhat <- mean( + rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05, + na.rm = TRUE + ) + n_high_rhat <- sum( + rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05, + na.rm = TRUE + ) + + + # Combine all diagnostic flags into one flag + diagnostic_flag <- any( + mean_accept_stat < 0.1, + p_divergent > 0.0075, # 0.0075 = 15 in 2000 samples are divergent + p_max_treedepth > 0.05, + p_high_rhat > 0.0075 + ) + # Create individual vectors for the columns of the diagnostics data frame + diagnostic_names <- c( + "mean_accept_stat", + "p_divergent", + "n_divergent", + "p_max_treedepth", + "p_high_rhat", + "n_high_rhat", + "diagnostic_flag", + "low_case_count_flag" + ) + diagnostic_values <- c( + mean_accept_stat, + p_divergent, + n_divergent, + p_max_treedepth, + p_high_rhat, + n_high_rhat, + diagnostic_flag, + low_case_count + ) + + data.frame( + diagnostic = diagnostic_names, + value = diagnostic_values, + job_id = job_id, + task_id = task_id, + disease = disease, + geo_value = geo_value, + model = model + ) +} + +#' Calculate low case count diagnostic flag +#' +#' The diagnostic flag is TRUE if either of the _last_ two weeks of the dataset +#' have fewer than an aggregate 10 cases per week. This aggregation excludes the +#' count from confirmed outliers, which have been set to NA in the data. +#' +#' This function assumes that the `df` input dataset has been +#' "completed": that any implicit missingness has been made explicit. +#' +#' @param df A dataframe as returned by [read_data()]. The dataframe must +#' include columns such as `reference_date` (a date vector) and `confirm` +#' (the number of confirmed cases per day). +#' +#' @return A logical value (TRUE or FALSE) indicating whether either of the last +#' two weeks in the dataset had fewer than 10 cases per week. +#' @export +low_case_count_diagnostic <- function(df) { + cli::cli_alert_info("Calculating low case count diagnostic") + # Get the dates in the last and second-to-last weeks + last_date <- as.Date(max(df[["reference_date"]], na.rm = TRUE)) + # Create week sequences explicitly in case of missingness + ult_week_min <- last_date - 6 + ult_week_max <- last_date + pen_week_min <- last_date - 13 + pen_week_max <- last_date - 7 + ultimate_week_dates <- seq.Date( + from = ult_week_min, + to = ult_week_max, + by = "day" + ) + penultimate_week_dates <- seq.Date( + from = pen_week_min, + to = pen_week_max, + by = "day" + ) + + ultimate_week_count <- sum( + df[ + df[["reference_date"]] %in% ultimate_week_dates, + "confirm" + ], + na.rm = TRUE + ) + penultimate_week_count <- sum( + df[ + df[["reference_date"]] %in% penultimate_week_dates, + "confirm" + ], + na.rm = TRUE + ) + + + cli::cli_alert_info(c( + "Ultimate week spans {format(ult_week_min, '%a, %Y-%m-%d')} ", + "to {format(ult_week_max, '%a, %Y-%m-%d')} with ", + "count {.val {ultimate_week_count}}" + )) + cli::cli_alert_info(c( + "Penultimate week spans ", + "{format(pen_week_min, '%a, %Y-%m-%d')} to ", + "{format(pen_week_max, '%a, %Y-%m-%d')} with ", + "count {.val {penultimate_week_count}}" + )) + + any( + ultimate_week_count < 10, + penultimate_week_count < 10 + ) +} diff --git a/R/write_output.R b/R/write_output.R new file mode 100644 index 00000000..bcc5e3d9 --- /dev/null +++ b/R/write_output.R @@ -0,0 +1,363 @@ +#' Write model outputs to specified directories +#' +#' Processes the model fit, extracts samples and quantiles, +#' and writes them to the appropriate directories. +#' +#' @param fit An EpiNow2 fit object with posterior estimates. +#' @param output_dir String. The base output directory path. +#' @param samples A data.table as returned by [process_samples()] +#' @param summaries A data.table as returned by [process_quantiles()] +#' @param job_id String. The identifier for the job. +#' @param task_id String. The identifier for the task. +#' @param metadata List. Additional metadata to be included in the output. +#' +#' @return Invisible NULL. The function is called for its side effects. +#' @export +write_model_outputs <- function( + fit, + samples, + summaries, + output_dir, + job_id, + task_id, + metadata = list()) { + rlang::try_fetch( + { + # Create directory structure + write_output_dir_structure(output_dir, job_id, task_id) + + # Write raw samples + samples_path <- file.path( + output_dir, + job_id, + "samples", + paste0(task_id, ".parquet") + ) + write_parquet(samples, samples_path) + cli::cli_alert_success("Wrote samples to {.path {samples_path}}") + + # Process and write summarized quantiles + summaries_path <- file.path( + output_dir, + job_id, + "summaries", + paste0(task_id, ".parquet") + ) + write_parquet(summaries, summaries_path) + cli::cli_alert_success("Wrote summaries to {.path {summaries_path}}") + + # Write EpiNow2 model + model_path <- file.path( + output_dir, + job_id, + "tasks", + task_id, + "model.rds" + ) + saveRDS(fit, model_path) + cli::cli_alert_success("Wrote model to {.path {model_path}}") + + # Write model run metadata + metadata_path <- file.path( + output_dir, + job_id, + "tasks", + task_id, + "metadata.json" + ) + jsonlite::write_json(metadata, metadata_path, pretty = TRUE) + cli::cli_alert_success("Wrote metadata to {.path {metadata_path}}") + }, + error = function(cnd) { + # Downgrade erroring out to a warning so we can catch and log + cli::cli_warn( + "Failure writing outputs", + parent = cnd, + class = "no_outputs" + ) + } + ) + + invisible(NULL) +} + +#' Create output directory structure for a given job and task. +#' +#' This function generates the necessary directory structure for storing output +#' files related to a job and its tasks, including directories for raw samples +#' and summarized quantiles. +#' +#' @param output_dir String. The base output directory path. +#' @param job_id String. The identifier for the job. +#' @param task_id String. The identifier for the task. +#' +#' @return The path to the base output directory (invisible). +#' @export +write_output_dir_structure <- function(output_dir, job_id, task_id) { + # Define the directory structure + dirs <- c( + output_dir, + file.path(output_dir, job_id), + file.path(output_dir, job_id, "tasks"), + file.path(output_dir, job_id, "samples"), + file.path(output_dir, job_id, "summaries"), + file.path(output_dir, job_id, "tasks", task_id) + ) + + # Create directories + lapply(dirs, dir.create, showWarnings = FALSE) + + invisible(output_dir) +} + +#' Extract posterior draws from a Stan fit object. +#' +#' This function extracts posterior draws for specific parameters from a Stan +#' fit object and prepares a fact table containing unique date-time-parameter +#' combinations for further merging. +#' +#' @param fit A Stan fit object with posterior estimates. +#' +#' @return A list containing two elements: `stan_draws` (the extracted draws in +#' long format) and `fact_table` (a table of unique date-time-parameter +#' combinations). +#' @noRd +extract_draws_from_fit <- function(fit) { + # Step 1: Extract unique date-time-parameter combinations + fact_table <- fit[["estimates"]][["samples"]][, + c("date", "time", "parameter"), + with = FALSE + ] + fact_table <- stats::na.omit(unique(fact_table)) + + # Step 1.1: Add corresponding 'obs_cases' rows for 'latent_cases' dates + # Some of the `*_reports` parameters are indexed from time 1, ..., T and some + # go to time T + forecast horizon. `imputed_reports` goes out to T + forecast + # horizon so we can do a downstream join and that will save only up to the max + # timepoint for that parameter. + obs_fact_table <- fact_table[ + fact_table[["parameter"]] == "imputed_reports", + ] + reports_fact_table <- data.table::copy(obs_fact_table) + + # The EpiNow2 summary table has the variable `imputed_reports` + # for nowcast-corrected cases, but not `obs_reports` for right- + # truncated cases to compare to the observed data. We want both. + # + # The dates for `obs_reports` are the same as for `imputed_reports` + # (their differences are the nowcast correction + error structure). + # From Sam: imputed reports have error and are corrected for right-truncation + # (a posterior pred of the final observed value). Obs reports is the + # expected value actually observed in real time but without obs error. + # Get the dates for `obs_reports` by pulling out the `imputed_reports` + # dates and update the associated variable name in-place. Bind it back + # to the original fact table to have all desired variable-date combinations. + data.table::set(obs_fact_table, j = "parameter", value = factor( + obs_fact_table[["parameter"]], + levels = c("imputed_reports"), + labels = c("obs_reports") + )) + data.table::set(reports_fact_table, j = "parameter", value = factor( + reports_fact_table[["parameter"]], + levels = c("imputed_reports"), + labels = c("reports") + )) + + + # Combine original fact_table with new 'obs_reports' rows + fact_table <- rbind(fact_table, + obs_fact_table, + reports_fact_table, + fill = TRUE + ) + data.table::setnames( + fact_table, + old = c("parameter"), + new = c(".variable") + ) + + # Step 2: Extract desired parameters from the Stan object as posterior draws + stanfit <- fit[["estimates"]][["fit"]] + # Hacky workaround to avoid R CMD check NOTE on "no visible global binding" + # for variables in a dataframe evaluated via NSE. To use tidybayes, we need to + # use NSE, so giving these a global binding. The standard dplyr hacks + # (str2lang, .data prefix) didn't work here because it's not dplyr and we're + # not accessing a dataframe. + imputed_reports <- obs_reports <- R <- r <- time <- NULL # nolint + stan_draws <- tidybayes::gather_draws( + stanfit, + reports[time], + imputed_reports[time], + obs_reports[time], + R[time], + r[time], + ) |> + data.table::as.data.table() + + return(list(stan_draws = stan_draws, fact_table = fact_table)) +} + +#' Post-process and merge posterior draws with a fact table. +#' +#' This function merges posterior draws with a fact table containing +#' date-time-parameter combinations. It also standardizes parameter names and +#' renames key columns. +#' +#' @param draws A data.table of posterior draws (either raw or summarized). +#' @param fact_table A data.table of unique date-time-parameter combinations. +#' +#' @return A data.table with merged posterior draws and standardized parameter +#' names. +#' @noRd +post_process_and_merge <- function( + draws, + fact_table, + geo_value, + model, + disease) { + # Step 1: Left join the date-time-parameter map onto the Stan draws + merged_dt <- merge( + draws, + fact_table, + by = c("time", ".variable"), + all.x = TRUE, + all.y = FALSE + ) + + # Step 2: Standardize parameter names + data.table::set(merged_dt, j = ".variable", value = factor( + merged_dt[[".variable"]], + levels = c( + "reports", + "imputed_reports", + "obs_reports", + "R", + "r" + ), + labels = c( + "expected_nowcast_cases", + "pp_nowcast_cases", + "expected_obs_cases", + "Rt", + "growth_rate" + ) + )) + + # Step 3: Rename columns as necessary + data.table::setnames( + merged_dt, + old = c( + ".draw", ".chain", ".variable", ".value", ".lower", ".upper", ".width", + ".point", ".interval", "date", ".iteration" + ), + new = c( + "_draw", "_chain", "_variable", "value", "_lower", "_upper", "_width", + "_point", "_interval", "reference_date", "_iteration" + ), + # If using summaries, skip draws-specific names + skip_absent = TRUE + ) + + # Metadata for downstream querying without path parsing or joins + data.table::set(merged_dt, j = "geo_value", value = factor(geo_value)) + data.table::set(merged_dt, j = "model", value = factor(model)) + data.table::set(merged_dt, j = "disease", value = factor(disease)) + + return(merged_dt) +} + +#' Process posterior samples from a Stan fit object (raw draws). +#' +#' Extracts raw posterior samples from a Stan fit object and post-processes +#' them, including merging with a fact table and standardizing the parameter +#' names. If calling `process_quantiles()` the 50% and 95% intervals are +#' returned in `{tidybayes}` format. +#' +#' @param fit An EpiNow2 fit object with posterior estimates. +#' @param disease,geo_value,model Metadata for downstream processing. +#' @param quantiles A vector of quantiles to base to [tidybayes::median_qi()] +#' +#' @return A data.table of posterior draws or quantiles, merged and processed. +#' +#' @name sample_processing_functions +NULL + +#' @rdname sample_processing_functions +#' @export +process_samples <- function(fit, geo_value, model, disease) { + draws_list <- extract_draws_from_fit(fit) + raw_processed_output <- post_process_and_merge( + draws_list$stan_draws, + draws_list$fact_table, + geo_value, + model, + disease + ) + return(raw_processed_output) +} + +#' @rdname sample_processing_functions +#' @export +process_quantiles <- function( + fit, + geo_value, + model, + disease, + quantiles) { + # Step 1: Extract the draws + draws_list <- extract_draws_from_fit(fit) + + # Step 2: Summarize the draws + .variable <- time <- NULL # nolint + summarized_draws <- draws_list$stan_draws |> + dplyr::group_by(.variable, time) |> + tidybayes::median_qi( + .width = quantiles, + ) |> + data.table::as.data.table() + + # Step 3: Post-process summarized draws + post_process_and_merge( + summarized_draws, + draws_list$fact_table, + geo_value, + model, + disease + ) +} + +write_parquet <- function(data, path) { + # This is bad practice but `dbBind()` doesn't allow us to parameterize COPY + # ... TO. The danger of doing it this way seems quite low risk because it's + # ephemeral from a temporary in-memory DB. There's no actual database to + # guard against a SQL injection attack and all the data are already available + # here. + query <- paste0( + "COPY (SELECT * FROM df) TO '", + path, + "' (FORMAT PARQUET, CODEC 'zstd')" + ) + con <- DBI::dbConnect(duckdb::duckdb()) + on.exit(expr = DBI::dbDisconnect(con)) + + rlang::try_fetch( + { + duckdb::duckdb_register(con, "df", data) + DBI::dbExecute( + con, + statement = query + ) + }, + error = function(con) { + cli::cli_abort( + c( + "Error writing data to {.path {path}}", + "Original error: {con}" + ), + class = "wrapped_invalid_query" + ) + } + ) + + invisible(path) +} diff --git a/README.md b/README.md index c4c95d6c..f0e42e9a 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,71 @@ This package implements functions for: 1. **Logging**: Steps in the pipeline have comprehensive R-style logging, with the the [cli](https://github.com/r-lib/cli) package 1. **Metadata**: Extract comprehensive metadata on the model run and store alongside outputs +## Output format + +The end goals of this package is to standardize the raw outputs from EpiNow2 into samples and summaries tables, and to write those standardized outputs, along with relevant metadata, logs, etc. to a standard directory structure. Once in CFA's standard format, the outputs can be passed into a separate pipeline that handles post-processing (e.g. plotting, scoring, analysis) of Rt estimates from several different Rt estimation models. + +### Directories + +The nested partitioning structure of the outputs is designed to facilitate both automated processes and manual investigation: files are organized by job and task IDs, allowing for efficient file operations using glob patterns, while also maintaining a clear hierarchy that aids human users in navigating to specific results or logs. Files meant primarily for machine-readable consumption (i.e., draws, summaries, diagnostics) are structured together to make globbing easier. Files meant primarily for human investigation (i.e., logs, model fit object) are grouped together by task to facilitate manual workflows. +In this workflow, task IDs correspond to location specific model runs (which are independent of one another) and the jobid refers to a unique model run and disease. For example, a production job should contain task IDs for each of the 50 states and the US, but a job submitted for testing or experimentation might contain a smaller number of tasks/locations. + +```bash +/ +├── job_/ +│ ├── raw_samples/ +│ │ ├── samples_.parquet +│ ├── summarized_quantiles/ +│ │ ├── summarized_.parquet +│ ├── diagnostics/ +│ │ ├── diagnostics_.parquet +│ ├── tasks/ +│ │ ├── task_/ +│ │ │ ├── model.rds +│ │ │ ├── metadata.json +│ │ │ ├── stdout.log +│ │ │ └── stderr.log +│ ├── job_metadata.json +``` + +`/`: The base output directory. This could, for example, be `/` in a Docker container or dedicated output directory. +- `job_/`: A directory named after the specific job identifier, containing all outputs related to that job. All tasks within a job share this same top-level directory. + - `raw_samples/`: A subdirectory within each job folder that holds the raw sample files from all tasks in the job. Task-specific *draws* output files all live together in this directory to enable easy globbing over task-partitioned outputs. + - `samples_.parquet`: A file containing raw samples from the model, associated with a particular task identifier. This file has columns `job_id`, `task_id`, `geo_value`, `disease`, `model`, `_draw`, `_chain`, `_iteration`, `_variable`, `value`, and `reference_date`. These variables follow the [{tidybayes}](https://mjskay.github.io/tidybayes/articles/tidybayes.html) specification. + - `summarized_quantiles/`: A subdirectory for storing summarized quantile data. Task-specific *summarized* output files all live together in this directory to enable easy globbing over task-partitioned outputs. + - `summarized_.parquet`: A file with summarized quantiles relevant to a specific task identifier. This file has columns `job_id`, `task_id`, `geo_value`, `disease`, `model`, `value`, `_lower`, `_upper`, `_width`, `_point`, `_interval`, and `reference_date`. These variables follow the [{tidybayes}](https://mjskay.github.io/tidybayes/articles/tidybayes.html) specification. + - `diagnostics/`: A subdirectory for storing model fit diagnostics. Task-specific *diagnostic* output files all live together in this directory to enable easy globbing over task-partitioned outputs. + - `diagnostic_.parquet`: A file with diagnostics relevant to a specific task identifier. This file has columns `diagnostic`, `value`, `job_id`, `task_id`, `geo_value`, `disease`, and `model`. + - `tasks/`: This directory contains subdirectories for each task within a job. These are files that are less likely to require globbing from the data lake than manual investigation, so are stored togehter. + - `task_/`: Each task has its own folder identified by the task ID, which includes several files: + - `model.rds`: An RDS file storing the EpiNow2 model object fit to the data. + - `metadata.json`: A JSON file containing additional metadata about the model run for this task. + - `stdout.log`: A log file capturing standard output from the model run process. + - `stderr.log`: A log file capturing standard error output from the model run process. +- `job_metadata.json`: A JSON file located in the root of each job's directory, providing metadata about the entire job. + +### Model-estimated quantities + +EpiNow2 estimates the incident cases $\hat y_{td}$ for timepoint $t \in \{1, ..., T\}$ and delay $d \in \{1, ..., D\}$ where $D \le T$. In the single vintage we're providing to EpiNow2, the delay $d$ moves inversely to timepoints, so $d = T - t + 1$. + +The observed data vector of length $T$ is $y_{td} \in W$. We supply a nowcasting correction PMF $\nu$ for the last $D$ timepoints where $\nu_d \in [0, 1],$ and $\sum_{d=1}^D\nu_d = 1$. We also have some priors $\Theta$. + +We use EpiNow2's generative model $f(y, \nu, \Theta)$. + +EpiNow2 is a forward model that produces an expected nowcasted case count for each $t$ and $d$ pair: $\hat \gamma_{td}$. + It applies the nowcasting correction $\nu$ to the last $D$ timepoints of $\hat \gamma$ to produce the expected right-truncated case count $\hat y$. Note that these _expected_ case counts (with and without right-truncation) don't have observation noise included. + +We can apply negative binomial observation noise using EpiNow2's estimate of the negative binomial overdispersion parameter $\hat \phi$ and the expected case counts. The posterior predictive distributions of nowcasted case counts is $\tilde \gamma \sim \text{NB}(\hat \gamma, \hat \phi)$. The posterior predicted right-truncated case count is $\tilde y \sim \text{NB}(\hat y, \hat \phi)$. + +We can get 3 of these 4 quantities pre-generated from the returned EpiNow2 Stan model: + +- $\hat \gamma$: The expected nowcasted case count is `reports[t]` +- $\hat y$: The expected right-truncated case count is `obs_reports[t]` +- $\tilde \gamma$: The posterior-predicted nowcasted case count is `imputed_reports[t]` +- $\tilde y$: The posterior-predicted right-truncated case count isn't returned by EpiNow2. + +We also save the $R_t$ estimate at time $t$ and the intrinsic growth rate at time $t$. + ## Project Admin - @zsusswein diff --git a/man/extract_diagnostics.Rd b/man/extract_diagnostics.Rd new file mode 100644 index 00000000..5eb2e006 --- /dev/null +++ b/man/extract_diagnostics.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_diagnostics.R +\name{extract_diagnostics} +\alias{extract_diagnostics} +\title{Extract diagnostic metrics from model fit and data} +\usage{ +extract_diagnostics(fit, data, job_id, task_id, disease, geo_value, model) +} +\arguments{ +\item{fit}{The model fit object from \code{EpiNow2}} + +\item{data}{A data frame containing the input data used in the model fit.} + +\item{job_id}{A unique identifier for the job} + +\item{task_id}{A unique identifier for the task} + +\item{disease, geo_value, model}{Metadata for downstream processing.} +} +\value{ +A \code{data.frame} containing the extracted diagnostic metrics. The +data frame includes the following columns: +\itemize{ +\item \code{diagnostic}: The name of the diagnostic metric. +\item \code{value}: The value of the diagnostic metric. +\item \code{job_id}: The unique identifier for the job. +\item \code{task_id}: The unique identifier for the task. +\item \code{disease,geo_value,model}: Metadata for downstream processing. +} +} +\description{ +This function extracts various diagnostic metrics from a fitted \code{EpiNow2} +model and provided data. It checks for low case counts and computes +diagnostics from the fitted model, including the mean acceptance +statistic, divergent transitions, maximum tree depth, and Rhat values. +These diagnostics are then flagged if they exceed specific thresholds, +and the results are returned as a data frame. +} +\details{ +The following diagnostics are calculated: +\itemize{ +\item \code{mean_accept_stat}: The average acceptance statistic across +all chains. +\item \code{p_divergent}: The \emph{proportion} of divergent transitions across +all samples. +\item \code{n_divergent}: The \emph{number} of divergent transitions across +all samples. +\item \code{p_max_treedepth}: The proportion of samples that hit the +maximum tree depth. +\item \code{p_high_rhat}: The \emph{proportion} of parameters with Rhat values +greater than 1.05, indicating potential convergence issues. +\item \code{n_high_rhat}: The \emph{number} of parameters with Rhat values +greater than 1.05, indicating potential convergence issues. +\item \code{low_case_count_flag}: A flag indicating if there are low case +counts in the data. See \code{low_case_count_diagnostic()} for more +information on this diagnostic. +\item \code{epinow2_diagnostic_flag}: A combined flag that indicates if +any diagnostic thresholds are exceeded. The diagnostic thresholds +(1) mean_accept_stat < 0.1, (2) p_divergent > 0.0075, (3) +p_max_treedepth > 0.05, and (4) p_high_rhat > 0.0075. +} +} diff --git a/man/low_case_count_diagnostic.Rd b/man/low_case_count_diagnostic.Rd new file mode 100644 index 00000000..94657e94 --- /dev/null +++ b/man/low_case_count_diagnostic.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_diagnostics.R +\name{low_case_count_diagnostic} +\alias{low_case_count_diagnostic} +\title{Calculate low case count diagnostic flag} +\usage{ +low_case_count_diagnostic(df) +} +\arguments{ +\item{df}{A dataframe as returned by \code{\link[=read_data]{read_data()}}. The dataframe must +include columns such as \code{reference_date} (a date vector) and \code{confirm} +(the number of confirmed cases per day).} +} +\value{ +A logical value (TRUE or FALSE) indicating whether either of the last +two weeks in the dataset had fewer than 10 cases per week. +} +\description{ +The diagnostic flag is TRUE if either of the \emph{last} two weeks of the dataset +have fewer than an aggregate 10 cases per week. This aggregation excludes the +count from confirmed outliers, which have been set to NA in the data. +} +\details{ +This function assumes that the \code{df} input dataset has been +"completed": that any implicit missingness has been made explicit. +} diff --git a/man/sample_processing_functions.Rd b/man/sample_processing_functions.Rd new file mode 100644 index 00000000..01ee99ce --- /dev/null +++ b/man/sample_processing_functions.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/write_output.R +\name{sample_processing_functions} +\alias{sample_processing_functions} +\alias{process_samples} +\alias{process_quantiles} +\title{Process posterior samples from a Stan fit object (raw draws).} +\usage{ +process_samples(fit, geo_value, model, disease) + +process_quantiles(fit, geo_value, model, disease, quantiles) +} +\arguments{ +\item{fit}{An EpiNow2 fit object with posterior estimates.} + +\item{disease, geo_value, model}{Metadata for downstream processing.} + +\item{quantiles}{A vector of quantiles to base to \code{\link[tidybayes:reexports]{tidybayes::median_qi()}}} +} +\value{ +A data.table of posterior draws or quantiles, merged and processed. +} +\description{ +Extracts raw posterior samples from a Stan fit object and post-processes +them, including merging with a fact table and standardizing the parameter +names. If calling \code{process_quantiles()} the 50\% and 95\% intervals are +returned in \code{{tidybayes}} format. +} diff --git a/man/write_model_outputs.Rd b/man/write_model_outputs.Rd new file mode 100644 index 00000000..f1991ed0 --- /dev/null +++ b/man/write_model_outputs.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/write_output.R +\name{write_model_outputs} +\alias{write_model_outputs} +\title{Write model outputs to specified directories} +\usage{ +write_model_outputs( + fit, + samples, + summaries, + output_dir, + job_id, + task_id, + metadata = list() +) +} +\arguments{ +\item{fit}{An EpiNow2 fit object with posterior estimates.} + +\item{samples}{A data.table as returned by \code{\link[=process_samples]{process_samples()}}} + +\item{summaries}{A data.table as returned by \code{\link[=process_quantiles]{process_quantiles()}}} + +\item{output_dir}{String. The base output directory path.} + +\item{job_id}{String. The identifier for the job.} + +\item{task_id}{String. The identifier for the task.} + +\item{metadata}{List. Additional metadata to be included in the output.} +} +\value{ +Invisible NULL. The function is called for its side effects. +} +\description{ +Processes the model fit, extracts samples and quantiles, +and writes them to the appropriate directories. +} diff --git a/man/write_output_dir_structure.Rd b/man/write_output_dir_structure.Rd new file mode 100644 index 00000000..e9cf169c --- /dev/null +++ b/man/write_output_dir_structure.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/write_output.R +\name{write_output_dir_structure} +\alias{write_output_dir_structure} +\title{Create output directory structure for a given job and task.} +\usage{ +write_output_dir_structure(output_dir, job_id, task_id) +} +\arguments{ +\item{output_dir}{String. The base output directory path.} + +\item{job_id}{String. The identifier for the job.} + +\item{task_id}{String. The identifier for the task.} +} +\value{ +The path to the base output directory (invisible). +} +\description{ +This function generates the necessary directory structure for storing output +files related to a job and its tasks, including directories for raw samples +and summarized quantiles. +} diff --git a/tests/testthat/data/sample_fit.rds b/tests/testthat/data/sample_fit.rds new file mode 100644 index 00000000..7a137f7e Binary files /dev/null and b/tests/testthat/data/sample_fit.rds differ diff --git a/tests/testthat/helper-write_parameter_file.R b/tests/testthat/helper-write_parameter_file.R index 387ec9ca..0d7e60a5 100644 --- a/tests/testthat/helper-write_parameter_file.R +++ b/tests/testthat/helper-write_parameter_file.R @@ -6,6 +6,7 @@ write_sample_parameters_file <- function(value, parameter, start_date, end_date) { + Sys.sleep(0.05) df <- data.frame( start_date = as.Date(start_date), geo_value = state, @@ -16,6 +17,7 @@ write_sample_parameters_file <- function(value, ) con <- DBI::dbConnect(duckdb::duckdb()) + on.exit(DBI::dbDisconnect(con)) duckdb::duckdb_register(con, "test_table", df) # This is bad practice but `dbBind()` doesn't allow us to parameterize COPY @@ -24,7 +26,6 @@ write_sample_parameters_file <- function(value, # guard against a SQL injection attack. query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'") DBI::dbExecute(con, query) - DBI::dbDisconnect(con) invisible(path) } diff --git a/tests/testthat/test-extract_diagnostics.R b/tests/testthat/test-extract_diagnostics.R new file mode 100644 index 00000000..f7cc948b --- /dev/null +++ b/tests/testthat/test-extract_diagnostics.R @@ -0,0 +1,158 @@ +test_that("Fitted model extracts diagnostics", { + # Arrange + data_path <- test_path("data/test_data.parquet") + con <- DBI::dbConnect(duckdb::duckdb()) + data <- DBI::dbGetQuery(con, " + SELECT + report_date, + reference_date, + disease, + geo_value AS state_abb, + value AS confirm + FROM read_parquet(?) + WHERE reference_date <= '2023-01-22'", + params = list(data_path) + ) + DBI::dbDisconnect(con) + fit_path <- test_path("data", "sample_fit.rds") + fit <- readRDS(fit_path) + + # Expected diagnostics + expected <- data.frame( + diagnostic = c( + "mean_accept_stat", + "p_divergent", + "n_divergent", + "p_max_treedepth", + "p_high_rhat", + "n_high_rhat", + "diagnostic_flag", + "low_case_count_flag" + ), + value = c( + 0.94240233, + 0.00000000, + 0.00000000, + 0.00000000, + 0.00000000, + 0.00000000, + 0.00000000, + 0.00000000 + ), + job_id = rep("test", 8), + task_id = rep("test", 8), + disease = rep("test", 8), + geo_value = rep("test", 8), + model = rep("test", 8), + stringsAsFactors = FALSE + ) + actual <- extract_diagnostics( + fit, + data, + "test", + "test", + "test", + "test", + "test" + ) + + testthat::expect_equal( + actual, + expected + ) +}) + +test_that("Cases below threshold returns TRUE", { + # Arrange + true_df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + confirm = c(9, rep(0, 12), 9) + ) + + # Act + diagnostic <- low_case_count_diagnostic(true_df) + + # Assert + expect_true(diagnostic) +}) + +test_that("Cases above threshold returns FALSE", { + # Arrange + false_df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + confirm = rep(10, 14) + ) + + # Act + diagnostic <- low_case_count_diagnostic(false_df) + + # Assert + expect_false(diagnostic) +}) + + +test_that("Only the last two weeks are evalated", { + # Arrange + # 3 weeks, first week would pass but last week does not + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 21 + ), + # Week 1: 700, Week 2: 700, Week 3: 0 + confirm = c(rep(100, 14), rep(0, 7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +}) + +test_that("Old approach's negative is now positive", { + # Arrange + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + # Week 1: 21, Week 2: 0 + confirm = c(rep(3, 7), rep(0, 7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +}) + +test_that("NAs are evalated as 0", { + # Arrange + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + # Week 1: 6 (not NA!), Week 2: 700 + confirm = c(NA_real_, rep(1, 6), rep(100, 7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +}) diff --git a/tests/testthat/test-write_output.R b/tests/testthat/test-write_output.R new file mode 100644 index 00000000..343f704b --- /dev/null +++ b/tests/testthat/test-write_output.R @@ -0,0 +1,317 @@ +test_that("write_model_outputs writes files and directories correctly", { + # Setup: Create a temporary directory and mock inputs + temp_output_dir <- tempfile() + dir.create(temp_output_dir) + + job_id <- "job_123" + task_id <- "task_456" + + # Create mock fit object + mock_fit <- list(estimates = 1:5) + mock_samples <- data.frame(x = 1) + mock_summaries <- data.frame(y = 2) + mock_metadata <- list(author = "Test", date = "2023-01-01") + + # Run the function + withr::with_tempdir({ + write_model_outputs( + mock_fit, + mock_samples, + mock_summaries, + ".", + job_id, + task_id, + mock_metadata + ) + + # Check if the directory structure was created + expect_true(dir.exists(file.path(job_id, "samples"))) + expect_true(dir.exists(file.path(job_id, "summaries"))) + expect_true(dir.exists(file.path(job_id, "tasks", task_id))) + + # Check if raw samples Parquet file was written + samples_file <- file.path( + job_id, + "samples", + paste0(task_id, ".parquet") + ) + expect_true(file.exists(samples_file)) + + # Check if summarized quantiles Parquet file was written + summarized_file <- file.path( + job_id, + "summaries", + paste0(task_id, ".parquet") + ) + expect_true(file.exists(summarized_file)) + + # Check if model rds file was written + model_file <- file.path( + job_id, + "tasks", + task_id, + "model.rds" + ) + expect_true(file.exists(model_file)) + + # Check if metadata JSON file was written + metadata_file <- file.path( + job_id, + "tasks", + task_id, + "metadata.json" + ) + expect_true(file.exists(metadata_file)) + + # Check file contents are right + con <- DBI::dbConnect(duckdb::duckdb()) + on.exit(expr = DBI::dbDisconnect(con)) + raw_samples_data <- DBI::dbGetQuery(con, + "SELECT * FROM read_parquet(?)", + params = list(samples_file) + ) + expect_equal(raw_samples_data, mock_samples) + + raw_summaries_data <- DBI::dbGetQuery(con, + "SELECT * FROM read_parquet(?)", + params = list(summarized_file) + ) + expect_equal(raw_summaries_data, mock_summaries) + + + written_metadata <- jsonlite::read_json(metadata_file) + jsonlite::write_json(mock_metadata, "expected.json") + expected_metadata <- jsonlite::read_json(metadata_file) + + expect_equal(written_metadata, expected_metadata) + }) +}) + +test_that("write_model_outputs handles errors correctly", { + # Setup: Use an invalid directory to trigger an error + invalid_output_dir <- "/invalid/path" + + # Create mock inputs + mock_fit <- list(samples = 1:5) + mock_metadata <- list(author = "Test", date = Sys.Date()) + + # Expect the function to raise a warning due to the invalid directory + withr::with_tempdir({ + expect_warning( + write_model_outputs( + mock_fit, + invalid_output_dir, + "job_123", + "task_456", + mock_metadata + ), + class = "no_outputs" + ) + }) +}) + +test_that("write_output_dir_structure generates dirs", { + withr::with_tempdir({ + write_output_dir_structure(".", job_id = "job", task_id = "task") + + expect_true(dir.exists(file.path("job", "tasks", "task"))) + expect_true(dir.exists(file.path("job", "samples"))) + expect_true(dir.exists(file.path("job", "summaries"))) + }) +}) + +test_that("process_quantiles works as expected", { + # Load the sample fit object + fit <- readRDS(test_path("data", "sample_fit.rds")) + + # Run the function on the fit object + result <- process_quantiles( + fit, + "test_geo", + "test_model", + "test_disease", + c(0.5, 0.95) + ) + + # Test 1: Check if the result is a data.table + expect_true( + data.table::is.data.table(result), + "The result should be a data.table" + ) + + # Test 2: Check if the necessary columns exist in the result + expected_columns <- c( + "time", + "_variable", + "value", + "_lower", + "_upper", + "_width", + "_point", + "_interval", + "reference_date", + "geo_value", + "model", + "disease" + ) + expect_setequal( + colnames(result), expected_columns + ) + + # Test 3: Check if the result contains the correct number of rows + expected_num_rows <- 50 + expect_equal(nrow(result), expected_num_rows, + info = paste("The result should have", expected_num_rows, "rows") + ) + + # Test 4: Check if the `parameter` column contains the expected values + expected_parameters <- c( + "Rt", + "expected_nowcast_cases", + "expected_obs_cases", + "growth_rate", + "pp_nowcast_cases" + ) + unique_parameters <- sort(unique(as.character(result[["_variable"]]))) + expect_equal( + unique_parameters, expected_parameters + ) + + # Test 5: Check if there are no missing values + expect_false( + anyNA(result), + "Columns have NA values" + ) + + # Test 6: Verify the left join: all `time` values from + # `stan_draws` should exist in the result + stan_draws <- tidybayes::gather_draws( + fit[["estimates"]][["fit"]], + imputed_reports[time], + obs_reports[time], + R[time], + r[time] + ) |> + tidybayes::median_qi(.width = c(0.5, 0.95)) |> + data.table::as.data.table() + + expect_true( + all(stan_draws$time %in% result$time), + "All time values from the Stan fit should be present in the result" + ) +}) + +test_that("process_samples works as expected", { + # Load the sample fit object + fit <- readRDS(test_path("data", "sample_fit.rds")) + + # Run the function on the fit object + result <- process_samples(fit, "test_geo", "test_model", "test_disease") + + # Test 1: Check if the result is a data.table + expect_true( + data.table::is.data.table(result), + "The result should be a data.table" + ) + + # Test 2: Check if the necessary columns exist in the result + expected_columns <- c( + "time", + "_variable", + "_chain", + "_iteration", + "_draw", + "value", + "reference_date", + "geo_value", + "model", + "disease" + ) + expect_setequal( + colnames(result), expected_columns + ) + + # Test 3: Check if the result contains the correct number of rows + expected_num_rows <- 2500 # Replace with actual expected value + expect_equal(nrow(result), expected_num_rows, + info = paste("The result should have", expected_num_rows, "rows") + ) + + # Test 4: Check if the `parameter` column contains the expected values + expected_parameters <- c( + "Rt", + "expected_nowcast_cases", + "expected_obs_cases", + "growth_rate", + "pp_nowcast_cases" + ) + unique_parameters <- sort(unique(as.character(result[["_variable"]]))) + expect_equal( + unique_parameters, expected_parameters + ) + + # Test 5: Check if there are no missing values + expect_false( + anyNA(result), + "Columns have NA values" + ) + + # Test 6: Verify the left join: all `time` values from + # `stan_draws` should exist in the result + stan_draws <- tidybayes::gather_draws( + fit[["estimates"]][["fit"]], + imputed_reports[time], + obs_reports[time], + R[time], + r[time] + ) |> + tidybayes::median_qi(.width = c(0.5, 0.95)) |> + data.table::as.data.table() + + expect_true( + all(stan_draws$time %in% result$time), + "All time values from the Stan fit should be present in the result" + ) +}) + +test_that("write_parquet successfully writes data to parquet", { + # Prepare temporary file and sample data + + + temp_path <- "test.parquet" + test_data <- data.frame( + id = 1:5, + value = c("A", "B", "C", "D", "E") + ) + + # Run the function + withr::with_tempdir({ + result <- write_parquet(test_data, temp_path) + + # Check if the function returns the correct path + expect_equal(result, temp_path) + + # Check if the parquet file exists + expect_true(file.exists(temp_path)) + + # Read the file back to ensure data was written correctly + con <- DBI::dbConnect(duckdb::duckdb()) + on.exit(expr = DBI::dbDisconnect(con)) + actual <- DBI::dbGetQuery(con, "SELECT * FROM 'test.parquet'") + + # Verify the data matches the input + expect_equal(actual, test_data) + }) +}) + +test_that("write_parquet handles errors correctly", { + # Prepare a temporary path that should fail (invalid directory) + invalid_path <- "/invalid/path/test.parquet" + test_data <- data.frame(id = 1:5, value = c("A", "B", "C", "D", "E")) + + # Expect the function to throw an error when writing to an invalid path + expect_error( + write_parquet(test_data, invalid_path), + class = "wrapped_invalid_query" + ) +})