Skip to content

Commit

Permalink
add coverage plots, fix bug in tidydraws test
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Nov 27, 2024
1 parent 31e9725 commit 9e392df
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 2 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Imports:
purrr,
readr,
rlang,
scales,
soql,
stringr,
tibble,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export(inferencedata_to_tidy_draws)
export(location_lookup)
export(nhsn_soda_query)
export(pivot_hubverse_quantiles_wider)
export(plot_coverage_by_date)
export(plot_forecast_quantiles)
export(plot_hubverse_pointintervals)
export(plot_hubverse_quantiles)
Expand Down
78 changes: 78 additions & 0 deletions R/plot_coverage.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#' Plot of empirical forecast coverage by
#' reference date.
#'
#' @param scored Output of [scoringutils::score()], not yet
#' summarized, containing a column for coverage at the required
#' coverage level.
#' @param coverage_level Decimal coverage level to plot, e.g.
#' `0.95` or `0.5`.
#' @param coverage_col Name of the column corresponding to that
#' coverage level in `scored`. Default `interval_coverage_<x>`
#' where <x> is the coverage level as a percentage, e.g. if
#' `coverage_level = 0.95`, then if `coverage_col` is not specified,
#' `plot_coverage_by_date` will look for a column named
#' `interval_coverage_95`, as this is the default name for
#' interval coverage columns produced by [scoringutils::score()] and
#' [scoringutils::summarise_scores()]
#' @param date_col Column containing dates, which will become the
#' x-axis in the empirical coverage by date plot. This can be
#' a target date, but more commonly it will be a forecast date,
#' or `reference_date` indicating when the forecast was produced.
#' Default `"reference_date"`, the standard name for a forecast
#' date in the hubverse schema.
#' @param group_cols Other columns to group by, in addition to forecast
#' date. These will become facets in the output ggplot. Default
#' `c("target", "horizon")` (i.e. group by forecasting target and
#' forecast horizon.
#' @param ytransform transform for the y axis, a string. Passed
#' as the `transform` argument to [ggplot2::scale_y_continuous()].
#' @param ylabels labeling scheme for the y axis. Passed as
#' the `labels` argument to [ggplot2::scale_y_continuous()]. Default
#' [scales::label_percent()].
#' @return A ggplot of the empirical coverage.
#' @export
plot_coverage_by_date <- function(scored,
coverage_level,
coverage_col = NULL,
date_col = "reference_date",
group_cols = c("target", "horizon"),
ytransform = "identity",
ylabels = scales::label_percent()) {
if (is.null(coverage_col)) {
coverage_col <-
glue::glue("interval_coverage_{coverage_level * 100}")
}

summarized <- scored |>
scoringutils::summarise_scores(
by = c(date_col, group_cols)
) |>
tibble::tibble()

fig <- ggplot2::ggplot(
data = summarized,
mapping = ggplot2::aes(
x = .data[[date_col]],
y = .data[[coverage_col]],
)
) +
geom_hline(yintercept = coverage_level) +
geom_point(size = 3) +
geom_line(linewidth = 2) +
scale_y_continuous(
transform = ytransform,
labels = ylabels
) +
coord_cartesian(ylim = c(0, 1))

## facet wrap if one or many group cols, facet grid
## if exactly two
if (length(group_cols) == 1 || length(group_cols > 2)) {
fig <- fig + facet_wrap(group_cols)
} else if (length(group_cols) == 2) {
fig <- fig + facet_grid(reformulate(
group_cols[1], group_cols[2]
))
}
return(fig)
}
2 changes: 2 additions & 0 deletions man/forecasttools-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions man/plot_coverage_by_date.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions tests/testthat/test_inferencedata_dataframe_to_tidydraws.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ testthat::test_that("inferencedata_to_tidy_draws converts data correctly", {
)
)

testthat::expect_no_error(spread_draws(result$data[[1]], a, b[x], c[y, z]))
testthat::expect_no_error(spread_draws(result$data[[2]], obs[a]))
testthat::expect_no_error(
tidybayes::spread_draws(result$data[[1]], a, b[x], c[y, z])
)
testthat::expect_no_error(
tidybayes::spread_draws(result$data[[2]], obs[a])
)
})

0 comments on commit 9e392df

Please sign in to comment.