generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Ignore CI/CD stuff in Rbuildignore * Extract diagnostics from fitted model * Basic output schema * Use `.pre-commit.config.yaml` from main To fix weirdness with unicode parsing error from.....somewhere? * Update output schema * Bump NEWS * Bump NEWS * Expand on readme * Use setequal for column name checks h/t @natemcintosh * Apply suggestions from code review Co-authored-by: Adam Howes <[email protected]> * Update with Adam's review * Update R/write_output.R * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update R/extract_diagnostics.R Co-authored-by: Katie Gostic (she/her) <[email protected]> * Add alert with dates for low case count diagnostic * Apply suggestions from code review Co-authored-by: Adam Howes <[email protected]> * Use new R-universe Stan repository * Update README.md Co-authored-by: Katie Gostic (she/her) <[email protected]> * Update README.md Co-authored-by: Katie Gostic (she/her) <[email protected]> * Condense dir creation * Expose quantiles for summarization * Save the description of the different EpiNow2 params * Add comment explaining why dates work * Clarify comment on EpiNow2 param outputs * Add `reports` to output --------- Co-authored-by: Adam Howes <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Katie Gostic (she/her) <[email protected]>
- Loading branch information
1 parent
f27ed7a
commit a72c5ba
Showing
19 changed files
with
1,298 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
NEWS.md merge=union | ||
|
||
# Normal text let sit to auto | ||
*.htm text | ||
*.html text | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
#' Extract diagnostic metrics from model fit and data | ||
#' | ||
#' This function extracts various diagnostic metrics from a fitted `EpiNow2` | ||
#' model and provided data. It checks for low case counts and computes | ||
#' diagnostics from the fitted model, including the mean acceptance | ||
#' statistic, divergent transitions, maximum tree depth, and Rhat values. | ||
#' These diagnostics are then flagged if they exceed specific thresholds, | ||
#' and the results are returned as a data frame. | ||
#' | ||
#' @param fit The model fit object from `EpiNow2` | ||
#' @param data A data frame containing the input data used in the model fit. | ||
#' @param job_id A unique identifier for the job | ||
#' @param task_id A unique identifier for the task | ||
#' @param disease,geo_value,model Metadata for downstream processing. | ||
#' | ||
#' @return A \code{data.frame} containing the extracted diagnostic metrics. The | ||
#' data frame includes the following columns: | ||
#' \itemize{ | ||
#' \item \code{diagnostic}: The name of the diagnostic metric. | ||
#' \item \code{value}: The value of the diagnostic metric. | ||
#' \item \code{job_id}: The unique identifier for the job. | ||
#' \item \code{task_id}: The unique identifier for the task. | ||
#' \item \code{disease,geo_value,model}: Metadata for downstream processing. | ||
#' } | ||
#' | ||
#' @details | ||
#' The following diagnostics are calculated: | ||
#' \itemize{ | ||
#' \item \code{mean_accept_stat}: The average acceptance statistic across | ||
#' all chains. | ||
#' \item \code{p_divergent}: The *proportion* of divergent transitions across | ||
#' all samples. | ||
#' \item \code{n_divergent}: The *number* of divergent transitions across | ||
#' all samples. | ||
#' \item \code{p_max_treedepth}: The proportion of samples that hit the | ||
#' maximum tree depth. | ||
#' \item \code{p_high_rhat}: The *proportion* of parameters with Rhat values | ||
#' greater than 1.05, indicating potential convergence issues. | ||
#' \item \code{n_high_rhat}: The *number* of parameters with Rhat values | ||
#' greater than 1.05, indicating potential convergence issues. | ||
#' \item \code{low_case_count_flag}: A flag indicating if there are low case | ||
#' counts in the data. See \code{low_case_count_diagnostic()} for more | ||
#' information on this diagnostic. | ||
#' \item \code{epinow2_diagnostic_flag}: A combined flag that indicates if | ||
#' any diagnostic thresholds are exceeded. The diagnostic thresholds | ||
#' (1) mean_accept_stat < 0.1, (2) p_divergent > 0.0075, (3) | ||
#' p_max_treedepth > 0.05, and (4) p_high_rhat > 0.0075. | ||
#' } | ||
#' @export | ||
extract_diagnostics <- function(fit, | ||
data, | ||
job_id, | ||
task_id, | ||
disease, | ||
geo_value, | ||
model) { | ||
low_case_count <- low_case_count_diagnostic(data) | ||
|
||
epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit, | ||
inc_warmup = FALSE | ||
) | ||
mean_accept_stat <- mean( | ||
sapply(epinow2_diagnostics, function(x) mean(x[, "accept_stat__"])) | ||
) | ||
p_divergent <- mean( | ||
rstan::get_divergent_iterations(fit$estimates$fit), | ||
na.rm = TRUE | ||
) | ||
n_divergent <- sum( | ||
rstan::get_divergent_iterations(fit$estimates$fit), | ||
na.rm = TRUE | ||
) | ||
p_max_treedepth <- mean( | ||
rstan::get_max_treedepth_iterations(fit$estimates$fit), | ||
na.rm = TRUE | ||
) | ||
p_high_rhat <- mean( | ||
rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05, | ||
na.rm = TRUE | ||
) | ||
n_high_rhat <- sum( | ||
rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05, | ||
na.rm = TRUE | ||
) | ||
|
||
|
||
# Combine all diagnostic flags into one flag | ||
diagnostic_flag <- any( | ||
mean_accept_stat < 0.1, | ||
p_divergent > 0.0075, # 0.0075 = 15 in 2000 samples are divergent | ||
p_max_treedepth > 0.05, | ||
p_high_rhat > 0.0075 | ||
) | ||
# Create individual vectors for the columns of the diagnostics data frame | ||
diagnostic_names <- c( | ||
"mean_accept_stat", | ||
"p_divergent", | ||
"n_divergent", | ||
"p_max_treedepth", | ||
"p_high_rhat", | ||
"n_high_rhat", | ||
"diagnostic_flag", | ||
"low_case_count_flag" | ||
) | ||
diagnostic_values <- c( | ||
mean_accept_stat, | ||
p_divergent, | ||
n_divergent, | ||
p_max_treedepth, | ||
p_high_rhat, | ||
n_high_rhat, | ||
diagnostic_flag, | ||
low_case_count | ||
) | ||
|
||
data.frame( | ||
diagnostic = diagnostic_names, | ||
value = diagnostic_values, | ||
job_id = job_id, | ||
task_id = task_id, | ||
disease = disease, | ||
geo_value = geo_value, | ||
model = model | ||
) | ||
} | ||
|
||
#' Calculate low case count diagnostic flag | ||
#' | ||
#' The diagnostic flag is TRUE if either of the _last_ two weeks of the dataset | ||
#' have fewer than an aggregate 10 cases per week. This aggregation excludes the | ||
#' count from confirmed outliers, which have been set to NA in the data. | ||
#' | ||
#' This function assumes that the `df` input dataset has been | ||
#' "completed": that any implicit missingness has been made explicit. | ||
#' | ||
#' @param df A dataframe as returned by [read_data()]. The dataframe must | ||
#' include columns such as `reference_date` (a date vector) and `confirm` | ||
#' (the number of confirmed cases per day). | ||
#' | ||
#' @return A logical value (TRUE or FALSE) indicating whether either of the last | ||
#' two weeks in the dataset had fewer than 10 cases per week. | ||
#' @export | ||
low_case_count_diagnostic <- function(df) { | ||
cli::cli_alert_info("Calculating low case count diagnostic") | ||
# Get the dates in the last and second-to-last weeks | ||
last_date <- as.Date(max(df[["reference_date"]], na.rm = TRUE)) | ||
# Create week sequences explicitly in case of missingness | ||
ult_week_min <- last_date - 6 | ||
ult_week_max <- last_date | ||
pen_week_min <- last_date - 13 | ||
pen_week_max <- last_date - 7 | ||
ultimate_week_dates <- seq.Date( | ||
from = ult_week_min, | ||
to = ult_week_max, | ||
by = "day" | ||
) | ||
penultimate_week_dates <- seq.Date( | ||
from = pen_week_min, | ||
to = pen_week_max, | ||
by = "day" | ||
) | ||
|
||
ultimate_week_count <- sum( | ||
df[ | ||
df[["reference_date"]] %in% ultimate_week_dates, | ||
"confirm" | ||
], | ||
na.rm = TRUE | ||
) | ||
penultimate_week_count <- sum( | ||
df[ | ||
df[["reference_date"]] %in% penultimate_week_dates, | ||
"confirm" | ||
], | ||
na.rm = TRUE | ||
) | ||
|
||
|
||
cli::cli_alert_info(c( | ||
"Ultimate week spans {format(ult_week_min, '%a, %Y-%m-%d')} ", | ||
"to {format(ult_week_max, '%a, %Y-%m-%d')} with ", | ||
"count {.val {ultimate_week_count}}" | ||
)) | ||
cli::cli_alert_info(c( | ||
"Penultimate week spans ", | ||
"{format(pen_week_min, '%a, %Y-%m-%d')} to ", | ||
"{format(pen_week_max, '%a, %Y-%m-%d')} with ", | ||
"count {.val {penultimate_week_count}}" | ||
)) | ||
|
||
any( | ||
ultimate_week_count < 10, | ||
penultimate_week_count < 10 | ||
) | ||
} |
Oops, something went wrong.