diff --git a/DESCRIPTION b/DESCRIPTION index b93fd1b..1e89939 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,6 +36,7 @@ Imports: purrr, readr, rlang, + scales, soql, stringr, tibble, diff --git a/NAMESPACE b/NAMESPACE index a8536cc..206f8c4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,6 +17,7 @@ export(inferencedata_to_tidy_draws) export(location_lookup) export(nhsn_soda_query) export(pivot_hubverse_quantiles_wider) +export(plot_coverage_by_date) export(plot_forecast_quantiles) export(plot_hubverse_pointintervals) export(plot_hubverse_quantiles) diff --git a/R/plot_coverage.R b/R/plot_coverage.R new file mode 100644 index 0000000..ef2a61e --- /dev/null +++ b/R/plot_coverage.R @@ -0,0 +1,78 @@ +#' Plot of empirical forecast coverage by +#' reference date. +#' +#' @param scored Output of [scoringutils::score()], not yet +#' summarized, containing a column for coverage at the required +#' coverage level. +#' @param coverage_level Decimal coverage level to plot, e.g. +#' `0.95` or `0.5`. +#' @param coverage_col Name of the column corresponding to that +#' coverage level in `scored`. Default `interval_coverage_` +#' where is the coverage level as a percentage, e.g. if +#' `coverage_level = 0.95`, then if `coverage_col` is not specified, +#' `plot_coverage_by_date` will look for a column named +#' `interval_coverage_95`, as this is the default name for +#' interval coverage columns produced by [scoringutils::score()] and +#' [scoringutils::summarise_scores()] +#' @param date_col Column containing dates, which will become the +#' x-axis in the empirical coverage by date plot. This can be +#' a target date, but more commonly it will be a forecast date, +#' or `reference_date` indicating when the forecast was produced. +#' Default `"reference_date"`, the standard name for a forecast +#' date in the hubverse schema. +#' @param group_cols Other columns to group by, in addition to forecast +#' date. These will become facets in the output ggplot. Default +#' `c("target", "horizon")` (i.e. group by forecasting target and +#' forecast horizon. +#' @param ytransform transform for the y axis, a string. Passed +#' as the `transform` argument to [ggplot2::scale_y_continuous()]. +#' @param ylabels labeling scheme for the y axis. Passed as +#' the `labels` argument to [ggplot2::scale_y_continuous()]. Default +#' [scales::label_percent()]. +#' @return A ggplot of the empirical coverage. +#' @export +plot_coverage_by_date <- function(scored, + coverage_level, + coverage_col = NULL, + date_col = "reference_date", + group_cols = c("target", "horizon"), + ytransform = "identity", + ylabels = scales::label_percent()) { + if (is.null(coverage_col)) { + coverage_col <- + glue::glue("interval_coverage_{coverage_level * 100}") + } + + summarized <- scored |> + scoringutils::summarise_scores( + by = c(date_col, group_cols) + ) |> + tibble::tibble() + + fig <- ggplot2::ggplot( + data = summarized, + mapping = ggplot2::aes( + x = .data[[date_col]], + y = .data[[coverage_col]], + ) + ) + + geom_hline(yintercept = coverage_level) + + geom_point(size = 3) + + geom_line(linewidth = 2) + + scale_y_continuous( + transform = ytransform, + labels = ylabels + ) + + coord_cartesian(ylim = c(0, 1)) + + ## facet wrap if one or many group cols, facet grid + ## if exactly two + if (length(group_cols) == 1 || length(group_cols > 2)) { + fig <- fig + facet_wrap(group_cols) + } else if (length(group_cols) == 2) { + fig <- fig + facet_grid(reformulate( + group_cols[1], group_cols[2] + )) + } + return(fig) +} diff --git a/man/forecasttools-package.Rd b/man/forecasttools-package.Rd index fbded7f..61d7878 100644 --- a/man/forecasttools-package.Rd +++ b/man/forecasttools-package.Rd @@ -12,6 +12,8 @@ A collection of tools for short-term forecasting within CFA Predict. Useful links: \itemize{ \item \url{https://cdcgov.github.io/forecasttools} + \item \url{https://github.com/CDCgov/forecasttools} + \item Report bugs at \url{https://github.com/CDCgov/forecasttools/issues} } } diff --git a/man/plot_coverage_by_date.Rd b/man/plot_coverage_by_date.Rd new file mode 100644 index 0000000..c4f3f6e --- /dev/null +++ b/man/plot_coverage_by_date.Rd @@ -0,0 +1,60 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_coverage.R +\name{plot_coverage_by_date} +\alias{plot_coverage_by_date} +\title{Plot of empirical forecast coverage by +reference date.} +\usage{ +plot_coverage_by_date( + scored, + coverage_level, + coverage_col = NULL, + date_col = "reference_date", + group_cols = c("target", "horizon"), + ytransform = "identity", + ylabels = scales::label_percent() +) +} +\arguments{ +\item{scored}{Output of \code{\link[scoringutils:score]{scoringutils::score()}}, not yet +summarized, containing a column for coverage at the required +coverage level.} + +\item{coverage_level}{Decimal coverage level to plot, e.g. +\code{0.95} or \code{0.5}.} + +\item{coverage_col}{Name of the column corresponding to that +coverage level in \code{scored}. Default \verb{interval_coverage_} +where \if{html}{\out{}} is the coverage level as a percentage, e.g. if +\code{coverage_level = 0.95}, then if \code{coverage_col} is not specified, +\code{plot_coverage_by_date} will look for a column named +\code{interval_coverage_95}, as this is the default name for +interval coverage columns produced by \code{\link[scoringutils:score]{scoringutils::score()}} and +\code{\link[scoringutils:summarise_scores]{scoringutils::summarise_scores()}}} + +\item{date_col}{Column containing dates, which will become the +x-axis in the empirical coverage by date plot. This can be +a target date, but more commonly it will be a forecast date, +or \code{reference_date} indicating when the forecast was produced. +Default \code{"reference_date"}, the standard name for a forecast +date in the hubverse schema.} + +\item{group_cols}{Other columns to group by, in addition to forecast +date. These will become facets in the output ggplot. Default +\code{c("target", "horizon")} (i.e. group by forecasting target and +forecast horizon.} + +\item{ytransform}{transform for the y axis, a string. Passed +as the \code{transform} argument to \code{\link[ggplot2:scale_continuous]{ggplot2::scale_y_continuous()}}.} + +\item{ylabels}{labeling scheme for the y axis. Passed as +the \code{labels} argument to \code{\link[ggplot2:scale_continuous]{ggplot2::scale_y_continuous()}}. Default +\code{\link[scales:label_percent]{scales::label_percent()}}.} +} +\value{ +A ggplot of the empirical coverage. +} +\description{ +Plot of empirical forecast coverage by +reference date. +} diff --git a/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R b/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R index df9d12a..def21b5 100644 --- a/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R +++ b/tests/testthat/test_inferencedata_dataframe_to_tidydraws.R @@ -25,6 +25,10 @@ testthat::test_that("inferencedata_to_tidy_draws converts data correctly", { ) ) - testthat::expect_no_error(spread_draws(result$data[[1]], a, b[x], c[y, z])) - testthat::expect_no_error(spread_draws(result$data[[2]], obs[a])) + testthat::expect_no_error( + tidybayes::spread_draws(result$data[[1]], a, b[x], c[y, z]) + ) + testthat::expect_no_error( + tidybayes::spread_draws(result$data[[2]], obs[a]) + ) })