Skip to content

Commit

Permalink
Write model run outputs (#41)
Browse files Browse the repository at this point in the history
* Ignore CI/CD stuff in Rbuildignore

* Extract diagnostics from fitted model

* Basic output schema

* Use `.pre-commit.config.yaml` from main

To fix weirdness with unicode parsing error from.....somewhere?

* Update output schema

* Bump NEWS

* Bump NEWS

* Expand on readme

* Use setequal for column name checks

h/t @natemcintosh

* Apply suggestions from code review

Co-authored-by: Adam Howes <[email protected]>

* Update with Adam's review

* Update R/write_output.R

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update R/extract_diagnostics.R

Co-authored-by: Katie Gostic (she/her) <[email protected]>

* Add alert with dates for low case count diagnostic

* Apply suggestions from code review

Co-authored-by: Adam Howes <[email protected]>

* Use new R-universe Stan repository

* Update README.md

Co-authored-by: Katie Gostic (she/her) <[email protected]>

* Update README.md

Co-authored-by: Katie Gostic (she/her) <[email protected]>

* Condense dir creation

* Expose quantiles for summarization

* Save the description of the different EpiNow2 params

* Add comment explaining why dates work

* Clarify comment on EpiNow2 param outputs

* Add `reports` to output

---------

Co-authored-by: Adam Howes <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Katie Gostic (she/her) <[email protected]>
  • Loading branch information
4 people committed Oct 15, 2024
1 parent f27ed7a commit a72c5ba
Show file tree
Hide file tree
Showing 19 changed files with 1,298 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
^CONTRIBUTING.md$
^DISCLAIMER.md$
^Dockerfile$
^Dockerfile-dependencies$
^LICENSE.md$
^Makefile$
^\.github$
Expand All @@ -10,6 +11,7 @@
^\.vscode$
^\.vscode$
^_pkgdown\.yml$
^batch-autoscale-formula.txt$
^code-of-conduct.md$
^codecov\.yml$
^data-raw$
Expand Down
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
NEWS.md merge=union

# Normal text let sit to auto
*.htm text
*.html text
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ repos:
entry: Cannot commit .Rhistory, .RData, .Rds or .rds.
language: fail
files: '\.(Rhistory|RData|Rds|rds)$'
exclude: '^tests/testthat/data/.*\.rds$'
# `exclude: <regex>` to allow committing specific files
# Secrets
- repo: https://github.com/Yelp/detect-secrets
Expand Down
11 changes: 9 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,23 @@ Imports:
AzureRMR,
AzureStor,
cli,
data.table,
DBI,
dplyr,
duckdb,
EpiNow2 (>= 1.4.0),
rlang
jsonlite,
rlang,
rstan,
tidybayes
Additional_repositories:
https://stan-dev.r-universe.dev
URL: https://cdcgov.github.io/cfa-epinow2-pipeline/
Imports:
cli,
jsonlite,
jsonvalidate,
rlang
Depends:
R (>= 2.10)
R (>= 3.50)
LazyData: true
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ export(fetch_config)
export(validate_config)
export(apply_exclusions)
export(download_from_azure_blob)
export(extract_diagnostics)
export(fetch_blob_container)
export(fetch_credential_from_env_var)
export(fit_model)
export(format_delay_interval)
export(format_generation_interval)
export(format_right_truncation)
export(low_case_count_diagnostic)
export(process_quantiles)
export(process_samples)
export(read_data)
export(read_disease_parameters)
export(read_exclusions)
export(read_interval_pmf)
export(write_model_outputs)
export(write_output_dir_structure)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
* Removed `add.R` placeholder
* Fix bugs in date casting caused by DuckDB v1.1.1 release
* Drop unused pre-commit hooks
* Write outputs to file
Empty file added R/es
Empty file.
195 changes: 195 additions & 0 deletions R/extract_diagnostics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#' Extract diagnostic metrics from model fit and data
#'
#' This function extracts various diagnostic metrics from a fitted `EpiNow2`
#' model and provided data. It checks for low case counts and computes
#' diagnostics from the fitted model, including the mean acceptance
#' statistic, divergent transitions, maximum tree depth, and Rhat values.
#' These diagnostics are then flagged if they exceed specific thresholds,
#' and the results are returned as a data frame.
#'
#' @param fit The model fit object from `EpiNow2`
#' @param data A data frame containing the input data used in the model fit.
#' @param job_id A unique identifier for the job
#' @param task_id A unique identifier for the task
#' @param disease,geo_value,model Metadata for downstream processing.
#'
#' @return A \code{data.frame} containing the extracted diagnostic metrics. The
#' data frame includes the following columns:
#' \itemize{
#' \item \code{diagnostic}: The name of the diagnostic metric.
#' \item \code{value}: The value of the diagnostic metric.
#' \item \code{job_id}: The unique identifier for the job.
#' \item \code{task_id}: The unique identifier for the task.
#' \item \code{disease,geo_value,model}: Metadata for downstream processing.
#' }
#'
#' @details
#' The following diagnostics are calculated:
#' \itemize{
#' \item \code{mean_accept_stat}: The average acceptance statistic across
#' all chains.
#' \item \code{p_divergent}: The *proportion* of divergent transitions across
#' all samples.
#' \item \code{n_divergent}: The *number* of divergent transitions across
#' all samples.
#' \item \code{p_max_treedepth}: The proportion of samples that hit the
#' maximum tree depth.
#' \item \code{p_high_rhat}: The *proportion* of parameters with Rhat values
#' greater than 1.05, indicating potential convergence issues.
#' \item \code{n_high_rhat}: The *number* of parameters with Rhat values
#' greater than 1.05, indicating potential convergence issues.
#' \item \code{low_case_count_flag}: A flag indicating if there are low case
#' counts in the data. See \code{low_case_count_diagnostic()} for more
#' information on this diagnostic.
#' \item \code{epinow2_diagnostic_flag}: A combined flag that indicates if
#' any diagnostic thresholds are exceeded. The diagnostic thresholds
#' (1) mean_accept_stat < 0.1, (2) p_divergent > 0.0075, (3)
#' p_max_treedepth > 0.05, and (4) p_high_rhat > 0.0075.
#' }
#' @export
extract_diagnostics <- function(fit,
data,
job_id,
task_id,
disease,
geo_value,
model) {
low_case_count <- low_case_count_diagnostic(data)

epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit,
inc_warmup = FALSE
)
mean_accept_stat <- mean(
sapply(epinow2_diagnostics, function(x) mean(x[, "accept_stat__"]))
)
p_divergent <- mean(
rstan::get_divergent_iterations(fit$estimates$fit),
na.rm = TRUE
)
n_divergent <- sum(
rstan::get_divergent_iterations(fit$estimates$fit),
na.rm = TRUE
)
p_max_treedepth <- mean(
rstan::get_max_treedepth_iterations(fit$estimates$fit),
na.rm = TRUE
)
p_high_rhat <- mean(
rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05,
na.rm = TRUE
)
n_high_rhat <- sum(
rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05,
na.rm = TRUE
)


# Combine all diagnostic flags into one flag
diagnostic_flag <- any(
mean_accept_stat < 0.1,
p_divergent > 0.0075, # 0.0075 = 15 in 2000 samples are divergent
p_max_treedepth > 0.05,
p_high_rhat > 0.0075
)
# Create individual vectors for the columns of the diagnostics data frame
diagnostic_names <- c(
"mean_accept_stat",
"p_divergent",
"n_divergent",
"p_max_treedepth",
"p_high_rhat",
"n_high_rhat",
"diagnostic_flag",
"low_case_count_flag"
)
diagnostic_values <- c(
mean_accept_stat,
p_divergent,
n_divergent,
p_max_treedepth,
p_high_rhat,
n_high_rhat,
diagnostic_flag,
low_case_count
)

data.frame(
diagnostic = diagnostic_names,
value = diagnostic_values,
job_id = job_id,
task_id = task_id,
disease = disease,
geo_value = geo_value,
model = model
)
}

#' Calculate low case count diagnostic flag
#'
#' The diagnostic flag is TRUE if either of the _last_ two weeks of the dataset
#' have fewer than an aggregate 10 cases per week. This aggregation excludes the
#' count from confirmed outliers, which have been set to NA in the data.
#'
#' This function assumes that the `df` input dataset has been
#' "completed": that any implicit missingness has been made explicit.
#'
#' @param df A dataframe as returned by [read_data()]. The dataframe must
#' include columns such as `reference_date` (a date vector) and `confirm`
#' (the number of confirmed cases per day).
#'
#' @return A logical value (TRUE or FALSE) indicating whether either of the last
#' two weeks in the dataset had fewer than 10 cases per week.
#' @export
low_case_count_diagnostic <- function(df) {
cli::cli_alert_info("Calculating low case count diagnostic")
# Get the dates in the last and second-to-last weeks
last_date <- as.Date(max(df[["reference_date"]], na.rm = TRUE))
# Create week sequences explicitly in case of missingness
ult_week_min <- last_date - 6
ult_week_max <- last_date
pen_week_min <- last_date - 13
pen_week_max <- last_date - 7
ultimate_week_dates <- seq.Date(
from = ult_week_min,
to = ult_week_max,
by = "day"
)
penultimate_week_dates <- seq.Date(
from = pen_week_min,
to = pen_week_max,
by = "day"
)

ultimate_week_count <- sum(
df[
df[["reference_date"]] %in% ultimate_week_dates,
"confirm"
],
na.rm = TRUE
)
penultimate_week_count <- sum(
df[
df[["reference_date"]] %in% penultimate_week_dates,
"confirm"
],
na.rm = TRUE
)


cli::cli_alert_info(c(
"Ultimate week spans {format(ult_week_min, '%a, %Y-%m-%d')} ",
"to {format(ult_week_max, '%a, %Y-%m-%d')} with ",
"count {.val {ultimate_week_count}}"
))
cli::cli_alert_info(c(
"Penultimate week spans ",
"{format(pen_week_min, '%a, %Y-%m-%d')} to ",
"{format(pen_week_max, '%a, %Y-%m-%d')} with ",
"count {.val {penultimate_week_count}}"
))

any(
ultimate_week_count < 10,
penultimate_week_count < 10
)
}
Loading

0 comments on commit a72c5ba

Please sign in to comment.