From db752061353b4f37cf82b14bb2bdfa3b48daebd2 Mon Sep 17 00:00:00 2001
From: Nate McIntosh
Date: Tue, 26 Nov 2024 10:33:53 -0600
Subject: [PATCH] Config class is now operational

This took a little more finessing than expected, particularly around using
lists for the sampler opts and the priors: I went back from S7 objects to
plain lists, and added a default list that shows the required keys and the
expected types.
---
 NAMESPACE                      |   4 -
 NEWS.md                        |   4 +-
 R/config.R                     | 130 ++++++++++------------------
 R/pipeline.R                   |  26 ++++--
 man/Config.Rd                  | 149 +++++++++++++++++++++++++++++++--
 man/Exclusions.Rd              |   5 +-
 man/GpPrior.Rd                 |  15 ----
 man/Priors.Rd                  |  16 ----
 man/RtPrior.Rd                 |  17 ----
 man/SamplerOpts.Rd             |  32 -------
 tests/testthat/test-pipeline.R |   6 +-
 11 files changed, 216 insertions(+), 188 deletions(-)
 delete mode 100644 man/GpPrior.Rd
 delete mode 100644 man/Priors.Rd
 delete mode 100644 man/RtPrior.Rd
 delete mode 100644 man/SamplerOpts.Rd

diff --git a/NAMESPACE b/NAMESPACE
index bdf3211..d57090e 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -5,12 +5,8 @@ export(Data)
 export(DelayInterval)
 export(Exclusions)
 export(GenerationInterval)
-export(GpPrior)
 export(Parameters)
-export(Priors)
 export(RightTruncation)
-export(RtPrior)
-export(SamplerOpts)
 export(apply_exclusions)
 export(download_from_azure_blob)
 export(execute_model_logic)
diff --git a/NEWS.md b/NEWS.md
index eae644d..5907392 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,8 @@
 # CFAEpiNow2Pipeline (development version)

+* Creating a Config class to make syncing configuration differences easier.
+* Add a JSON reader for the Config class.
+* Use the Config class throughout the pipeline.
 * Adding a script to setup the Azure Batch Pool to link the container.
 * Adding new action to post a comment on PRs with a link to the rendered pkgdown site.
 * Add inner pipeline responsible for running the model fitting process
@@ -33,4 +36,3 @@
 * Fix bugs in parameter reading from local test run
 * Fix bugs in parameter reading from local test run
 * Add "US" as an option in `state_abb`
-* Creating a Config class to make syncing configuration differences easier.
diff --git a/R/config.R b/R/config.R
index 087849f..9cc8027 100644
--- a/R/config.R
+++ b/R/config.R
@@ -7,11 +7,14 @@ character_or_null <- S7::new_union(S7::class_character, NULL)
 #' @param path A string specifying the path to a CSV file containing exclusion
 #' data. It should include at least the columns: `reference_date`,
 #' `report_date`, ' `state_abb`, `disease`.
+#' @param blob_storage_container Optional. The name of the blob storage
+#' container to get it from. If NULL, will look locally.
 #' @export
 Exclusions <- S7::new_class(
   "Exclusions",
   properties = list(
-    path = character_or_null
+    path = character_or_null,
+    blob_storage_container = character_or_null
   )
 )

@@ -26,7 +29,7 @@ Exclusions <- S7::new_class(
 GenerationInterval <- S7::new_class(
   "GenerationInterval",
   properties = list(
-    path = S7::class_character,
+    path = character_or_null,
     blob_storage_container = character_or_null
   )
 )
@@ -42,7 +45,7 @@ GenerationInterval <- S7::new_class(
 DelayInterval <- S7::new_class(
   "DelayInterval",
   properties = list(
-    path = S7::class_character,
+    path = character_or_null,
     blob_storage_container = character_or_null
   )
 )
@@ -58,7 +61,7 @@ DelayInterval <- S7::new_class(
 RightTruncation <- S7::new_class(
   "RightTruncation",
   properties = list(
-    path = S7::class_character,
+    path = character_or_null,
     blob_storage_container = character_or_null
   )
 )
@@ -80,51 +83,6 @@ Parameters <- S7::new_class(
   )
 )

-#' RtPrior Class
-#'
-#' Represents the Rt prior parameters.
-#'
-#' @param mean A numeric value representing the mean of the Rt prior.
-#' @param sd A numeric value representing the standard deviation of the Rt
-#' prior.
-#' @export
-RtPrior <- S7::new_class(
-  "RtPrior",
-  properties = list(
-    mean = S7::class_numeric,
-    sd = S7::class_numeric
-  )
-)
-
-#' GpPrior Class
-#'
-#' Represents the Gaussian Process prior parameters.
-#'
-#' @param alpha_sd A numeric value representing the standard deviation of the
-#' alpha parameter in the GP prior.
-#' @export
-GpPrior <- S7::new_class(
-  "GpPrior",
-  properties = list(
-    alpha_sd = S7::class_numeric
-  )
-)
-
-#' Priors Class
-#'
-#' Holds all prior-related configurations for the pipeline.
-#'
-#' @param rt An instance of `RtPrior` class.
-#' @param gp An instance of `GpPrior` class.
-#' @export
-Priors <- S7::new_class(
-  "Priors",
-  properties = list(
-    rt = S7::S7_class(RtPrior()),
-    gp = S7::S7_class(GpPrior())
-  )
-)
-
 #' Data Class
 #'
 #' Represents the data-related configurations.
@@ -148,30 +106,6 @@ Data <- S7::new_class(
   )
 )

-#' SamplerOpts Class
-#'
-#' Represents the sampler options for the pipeline.
-#'
-#' @param cores An integer specifying the number of CPU cores to use.
-#' @param chains An integer specifying the number of Markov chains.
-#' @param iter_warmup An integer specifying the number of warmup iterations.
-#' @param iter_sampling An integer specifying the number of sampling iterations.
-#' @param adapt_delta A numeric value for the target acceptance probability.
-#' @param max_treedepth An integer specifying the maximum tree depth for the
-#' sampler.
-#' @export
-SamplerOpts <- S7::new_class(
-  "SamplerOpts",
-  properties = list(
-    cores = S7::class_integer,
-    chains = S7::class_integer,
-    iter_warmup = S7::class_integer,
-    iter_sampling = S7::class_integer,
-    adapt_delta = S7::class_numeric,
-    max_treedepth = S7::class_integer
-  )
-)
-
 #' Config Class
 #'
 #' Represents the complete configuration for the pipeline.
@@ -188,11 +122,13 @@ SamplerOpts <- S7::new_class(
 #' @param data An instance of `Data` class containing data configurations.
 #' @param seed An integer for setting the random seed.
 #' @param horizon An integer specifying the forecasting horizon.
-#' @param priors An instance of `Priors` class containing prior configurations.
+#' @param priors A list of lists. The first level should contain the key `rt`
+#' with elements `mean` and `sd` and the key `gp` with element `alpha_sd`.
 #' @param parameters An instance of `Parameters` class containing parameter
 #' configurations.
-#' @param sampler_opts An instance of `SamplerOpts` class containing sampler
-#' options.
+#' @param sampler_opts A list. The Stan sampler options to be passed through
+#' EpiNow2. It has required keys: `cores`, `chains`, `iter_warmup`,
+#' `iter_sampling`, `max_treedepth`, and `adapt_delta`.
 #' @param exclusions An instance of `Exclusions` class containing exclusion
 #' criteria.
 #' @param config_version A numeric value specifying the configuration version.
@@ -201,6 +137,8 @@ SamplerOpts <- S7::new_class(
 #' @param model A string specifying the model to be used.
 #' @param report_date A string representing the report date. Formatted as
 #' "YYYY-MM-DD".
+#' @param as_of_date A string representing the as-of date. Formatted as
+#' "YYYY-MM-DD".
 #' @export
 Config <- S7::new_class(
   "Config",
@@ -210,6 +148,7 @@ Config <- S7::new_class(
     min_reference_date = S7::class_character,
     max_reference_date = S7::class_character,
     report_date = S7::class_character,
+    as_of_date = S7::class_character,
     disease = S7::class_character,
     geo_value = S7::class_character,
     geo_type = S7::class_character,
@@ -219,9 +158,30 @@ Config <- S7::new_class(
     config_version = S7::class_character,
     quantile_width = S7::new_property(S7::class_vector, default = c(0.5, 0.95)),
     data = S7::S7_class(Data()),
-    priors = S7::S7_class(Priors()),
+    # Using a list instead of an S7 object, because EpiNow2 expects a list, and
+    # because it reduces changes to the pipeline code.
+    # Adding a default that shows the required keys, with the expected types as
+    # the values. Should fail loudly if the default values are used, but will
+    # be useful to the user to see what is expected.
+    priors = S7::new_property(S7::class_list, default = list(
+      rt = list(mean = S7::class_numeric, sd = S7::class_numeric),
+      gp = list(alpha_sd = S7::class_numeric)
+    )),
     parameters = S7::S7_class(Parameters()),
-    sampler_opts = S7::S7_class(SamplerOpts()),
+    # Using a list instead of an S7 object, because stan expects a list, and
+    # because it reduces changes to the pipeline code.
+    # Adding a default that shows the required keys, with the expected types as
+    # the values. Should fail loudly if the default values are used, but will
+    # be useful to the user to see what is expected.
+    # Using a list here also reduces changes to the pipeline code.
+    sampler_opts = S7::new_property(S7::class_list, default = list(
+      cores = S7::class_integer,
+      chains = S7::class_integer,
+      iter_warmup = S7::class_integer,
+      iter_sampling = S7::class_integer,
+      max_treedepth = S7::class_integer,
+      adapt_delta = S7::class_numeric
+    )),
     exclusions = S7::S7_class(Exclusions())
   )
 )
@@ -242,12 +202,8 @@ read_json_into_config <- function(config_path) {
   # create a more automated way to do this, we can remove this.
str2class <- list( data = Data, - priors = Priors, parameters = Parameters, - sampler_opts = SamplerOpts, exclusions = Exclusions, - rt = RtPrior, - gp = GpPrior, generation_interval = GenerationInterval, delay_interval = DelayInterval, right_truncation = RightTruncation @@ -260,9 +216,9 @@ read_json_into_config <- function(config_path) { missing_properties <- setdiff(S7::prop_names(Config()), names(raw_input)) # Error out if missing any fields if (length(missing_properties) > 0) { - cli::cli_abort(c( - "The following properties are missing from the config file:", - "{.var missing_properties}" + cli::cli_alert_info(c( + "The following expected properties were not in the config file:", + "{.var {missing_properties}}" )) } @@ -279,7 +235,9 @@ read_json_into_config <- function(config_path) { raw_data[[prop_name]], str2class[[prop_name]] ) } else if (!(prop_name %in% S7::prop_names(class_to_fill()))) { - cli::cli_warn("No Config field matching {.var {prop_name}}. Not using.") + cli::cli_alert_info( + "No Config field matching {.var {prop_name}}. Not using." + ) } else { # Else set it directly S7::prop(config, prop_name) <- raw_data[[prop_name]] diff --git a/R/pipeline.R b/R/pipeline.R index a51a169..7ed7bc7 100644 --- a/R/pipeline.R +++ b/R/pipeline.R @@ -71,15 +71,24 @@ orchestrate_pipeline <- function(config_path, blob_storage_container = NULL, output_dir = "/") { - # TODO: Add config reader here - config <- jsonlite::read_json(config_path, - simplifyVector = TRUE + config <- rlang::try_fetch( + read_json_into_config(config_path), + error = function(con) { + cli::cli_warn("Bad config file", + parent = con, + class = "Bad_config" + ) + FALSE + } ) + if (typeof(config) == "logical") { + return(invisible(FALSE)) + } write_output_dir_structure( output_dir = output_dir, - job_id = config[["job_id"]], - task_id = config[["task_id"]] + job_id = config@job_id, + task_id = config@task_id ) # Set up logs @@ -104,8 +113,8 @@ orchestrate_pipeline <- function(config_path, ) on.exit(sink(file = NULL)) cli::cli_alert_info("Starting run at {Sys.time()}") - cli::cli_alert_info("Using job id {.field {config[['job_id']]}}") - cli::cli_alert_info("Using task id {.field {config[['task_id']]}}") + cli::cli_alert_info("Using job id {.field {config@job_id}}") + cli::cli_alert_info("Using task id {.field {config@task_id}}") # Errors within `execute_model_logic()` are downgraded to warnings so # they can be logged and stored in Blob. 
If there is an error, @@ -153,7 +162,8 @@ execute_model_logic <- function(config, output_dir) { min_reference_date = config@min_reference_date ) - if (!rlang::is_null(config@exclusions@path)) { + # rlang::is_empty() checks for empty and NULL values + if (!rlang::is_empty(config@exclusions@path)) { exclusions_df <- read_exclusions(config@exclusions@path) cases_df <- apply_exclusions(cases_df, exclusions_df) } else { diff --git a/man/Config.Rd b/man/Config.Rd index 7f75f9c..7ab564b 100644 --- a/man/Config.Rd +++ b/man/Config.Rd @@ -10,6 +10,7 @@ Config( min_reference_date = character(0), max_reference_date = character(0), report_date = character(0), + as_of_date = character(0), disease = character(0), geo_value = character(0), geo_type = character(0), @@ -19,9 +20,142 @@ Config( config_version = character(0), quantile_width = c(0.5, 0.95), data = Data(), - priors = Priors(), + priors = list(rt = list(mean = structure(list(classes = list(structure(list(class = + "integer", constructor_name = "integer", constructor = function (.data = integer(0)) + + .data, validator = function (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + + } + }), class = "S7_base_class"), structure(list(class = "double", constructor_name = + "double", constructor = function (.data = numeric(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"))), class = "S7_union"), sd = structure(list(classes = + list(structure(list(class = "integer", constructor_name = "integer", constructor = + function (.data = integer(0)) + .data, validator = function (object) + { + if + (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), structure(list(class = "double", constructor_name = + "double", constructor = function (.data = numeric(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"))), class = "S7_union")), gp = list(alpha_sd = + structure(list(classes = list(structure(list(class = "integer", constructor_name = + "integer", constructor = function (.data = integer(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), structure(list(class = "double", constructor_name = + "double", constructor = function (.data = numeric(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"))), class = "S7_union"))), parameters = Parameters(), - sampler_opts = SamplerOpts(), + sampler_opts = list(cores = structure(list(class = "integer", constructor_name = + "integer", constructor = function (.data = integer(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), chains = structure(list(class = "integer", + constructor_name = "integer", constructor = function (.data = integer(0)) + .data, + validator = function (object) + { + if (base_class(object) 
!= name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), iter_warmup = structure(list(class = "integer", + constructor_name = "integer", constructor = function (.data = integer(0)) + .data, + validator = function (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), iter_sampling = structure(list(class = "integer", + constructor_name = "integer", constructor = function (.data = integer(0)) + .data, + validator = function (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), max_treedepth = structure(list(class = "integer", + constructor_name = "integer", constructor = function (.data = integer(0)) + .data, + validator = function (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), adapt_delta = structure(list(classes = + list(structure(list(class = "integer", constructor_name = "integer", constructor = + function (.data = integer(0)) + .data, validator = function (object) + { + if + (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"), structure(list(class = "double", constructor_name = + "double", constructor = function (.data = numeric(0)) + .data, validator = function + (object) + { + if (base_class(object) != name) { + + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) + } + + }), class = "S7_base_class"))), class = "S7_union")), exclusions = Exclusions() ) } @@ -39,6 +173,9 @@ date. Formatted as "YYYY-MM-DD".} \item{report_date}{A string representing the report date. Formatted as "YYYY-MM-DD".} +\item{as_of_date}{A string representing the as-of date. Formatted as +"YYYY-MM-DD".} + \item{disease}{A string specifying the disease being modeled.} \item{geo_value}{A string specifying the geographic value, usually a state.} @@ -58,13 +195,15 @@ quantiles.} \item{data}{An instance of \code{Data} class containing data configurations.} -\item{priors}{An instance of \code{Priors} class containing prior configurations.} +\item{priors}{A list of lists. The first level should contain the key \code{rt} +with elements \code{mean} and \code{sd} and the key \code{gp} with element \code{alpha_sd}.} \item{parameters}{An instance of \code{Parameters} class containing parameter configurations.} -\item{sampler_opts}{An instance of \code{SamplerOpts} class containing sampler -options.} +\item{sampler_opts}{A list. The Stan sampler options to be passed through +EpiNow2. It has required keys: \code{cores}, \code{chains}, \code{iter_warmup}, +\code{iter_sampling}, \code{max_treedepth}, and \code{adapt_delta}.} \item{exclusions}{An instance of \code{Exclusions} class containing exclusion criteria.} diff --git a/man/Exclusions.Rd b/man/Exclusions.Rd index 9f4c43a..d58cc13 100644 --- a/man/Exclusions.Rd +++ b/man/Exclusions.Rd @@ -4,12 +4,15 @@ \alias{Exclusions} \title{Exclusions Class} \usage{ -Exclusions(path = character(0)) +Exclusions(path = character(0), blob_storage_container = character(0)) } \arguments{ \item{path}{A string specifying the path to a CSV file containing exclusion data. 
It should include at least the columns: \code{reference_date},
 \code{report_date}, ' \code{state_abb}, \code{disease}.}
+
+\item{blob_storage_container}{Optional. The name of the blob storage
+container to get it from. If NULL, will look locally.}
 }
 \description{
 Represents exclusion criteria for the pipeline.
 }
diff --git a/man/GpPrior.Rd b/man/GpPrior.Rd
deleted file mode 100644
index ea5ec25..0000000
--- a/man/GpPrior.Rd
+++ /dev/null
@@ -1,15 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/config.R
-\name{GpPrior}
-\alias{GpPrior}
-\title{GpPrior Class}
-\usage{
-GpPrior(alpha_sd = integer(0))
-}
-\arguments{
-\item{alpha_sd}{A numeric value representing the standard deviation of the
-alpha parameter in the GP prior.}
-}
-\description{
-Represents the Gaussian Process prior parameters.
-}
diff --git a/man/Priors.Rd b/man/Priors.Rd
deleted file mode 100644
index 8845c92..0000000
--- a/man/Priors.Rd
+++ /dev/null
@@ -1,16 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/config.R
-\name{Priors}
-\alias{Priors}
-\title{Priors Class}
-\usage{
-Priors(rt = RtPrior(), gp = GpPrior())
-}
-\arguments{
-\item{rt}{An instance of \code{RtPrior} class.}
-
-\item{gp}{An instance of \code{GpPrior} class.}
-}
-\description{
-Holds all prior-related configurations for the pipeline.
-}
diff --git a/man/RtPrior.Rd b/man/RtPrior.Rd
deleted file mode 100644
index 2b99938..0000000
--- a/man/RtPrior.Rd
+++ /dev/null
@@ -1,17 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/config.R
-\name{RtPrior}
-\alias{RtPrior}
-\title{RtPrior Class}
-\usage{
-RtPrior(mean = integer(0), sd = integer(0))
-}
-\arguments{
-\item{mean}{A numeric value representing the mean of the Rt prior.}
-
-\item{sd}{A numeric value representing the standard deviation of the Rt
-prior.}
-}
-\description{
-Represents the Rt prior parameters.
-}
diff --git a/man/SamplerOpts.Rd b/man/SamplerOpts.Rd
deleted file mode 100644
index f3149da..0000000
--- a/man/SamplerOpts.Rd
+++ /dev/null
@@ -1,32 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/config.R
-\name{SamplerOpts}
-\alias{SamplerOpts}
-\title{SamplerOpts Class}
-\usage{
-SamplerOpts(
-  cores = integer(0),
-  chains = integer(0),
-  iter_warmup = integer(0),
-  iter_sampling = integer(0),
-  adapt_delta = integer(0),
-  max_treedepth = integer(0)
-)
-}
-\arguments{
-\item{cores}{An integer specifying the number of CPU cores to use.}
-
-\item{chains}{An integer specifying the number of Markov chains.}
-
-\item{iter_warmup}{An integer specifying the number of warmup iterations.}
-
-\item{iter_sampling}{An integer specifying the number of sampling iterations.}
-
-\item{adapt_delta}{A numeric value for the target acceptance probability.}
-
-\item{max_treedepth}{An integer specifying the maximum tree depth for the
-sampler.}
-}
-\description{
-Represents the sampler options for the pipeline.
-}
diff --git a/tests/testthat/test-pipeline.R b/tests/testthat/test-pipeline.R
index 08bad59..a0dda93 100644
--- a/tests/testthat/test-pipeline.R
+++ b/tests/testthat/test-pipeline.R
@@ -71,7 +71,7 @@ test_that("Pipeline run produces expected outputs with exclusions", {
 test_that("Process pipeline produces expected outputs and returns success", {
   # Arrange
   config_path <- test_path("data", "sample_config_with_exclusion.json")
-  config <- jsonlite::read_json(config_path)
+  config <- read_json_into_config(config_path) # Read from locally
   output_dir <- "pipeline_test"
   on.exit(unlink(output_dir, recursive = TRUE))

@@ -87,7 +87,7 @@ test_that("Process pipeline produces expected outputs and returns success", {
   # Assert output files all exist
   expect_pipeline_files_written(
     output_dir,
-    config[["job_id"]],
-    config[["task_id"]]
+    config@job_id,
+    config@task_id
   )
 })
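
For reference, below is a minimal sketch (not part of the patch) of how the reworked Config class can be constructed once this change is applied, with `priors` and `sampler_opts` supplied as plain lists in place of the removed Priors()/RtPrior()/GpPrior()/SamplerOpts() objects. All field values here (IDs, dates, priors, sampler settings, and the exclusions path) are hypothetical placeholders, not values from the repository; a JSON file with the same field names could instead be loaded with read_json_into_config().

# Hypothetical values throughout; this sketch only illustrates the list-based
# priors and sampler_opts fields introduced by this patch.
library(CFAEpiNow2Pipeline)

config <- Config(
  job_id = "job-001",
  task_id = "task-001",
  min_reference_date = "2024-10-01",
  max_reference_date = "2024-11-20",
  report_date = "2024-11-26",
  as_of_date = "2024-11-26",
  disease = "COVID-19",
  geo_value = "MT",
  geo_type = "state",
  model = "EpiNow2",
  config_version = "1",
  # Priors are a list of lists with key `rt` (mean, sd) and key `gp` (alpha_sd)
  priors = list(
    rt = list(mean = 1, sd = 0.2),
    gp = list(alpha_sd = 0.01)
  ),
  # Sampler options are a plain list passed through to the Stan sampler
  sampler_opts = list(
    cores = 2L,
    chains = 2L,
    iter_warmup = 250L,
    iter_sampling = 500L,
    max_treedepth = 12L,
    adapt_delta = 0.99
  ),
  exclusions = Exclusions(path = "exclusions.csv")
)

# Properties are read with S7's `@`, as the updated pipeline code now does
config@priors$rt$mean
config@sampler_opts$chains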