Skip to content

Commit

Permalink
Config class is now operational
Browse files Browse the repository at this point in the history
This took a little more finessing than expected. Particularly around using lists for the sampler opts and the priors, I went back to lists from S7 objects, and added a default list that shows the desired keys and the expected types.
  • Loading branch information
natemcintosh committed Nov 26, 2024
1 parent eefe9e9 commit db75206
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 188 deletions.
4 changes: 0 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@ export(Data)
export(DelayInterval)
export(Exclusions)
export(GenerationInterval)
export(GpPrior)
export(Parameters)
export(Priors)
export(RightTruncation)
export(RtPrior)
export(SamplerOpts)
export(apply_exclusions)
export(download_from_azure_blob)
export(execute_model_logic)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# CFAEpiNow2Pipeline (development version)

* Creating a Config class to make syncing configuration differences easier.
* Add a JSON reader for the Config class.
* Use the Config class throughout the pipeline.
* Adding a script to setup the Azure Batch Pool to link the container.
* Adding new action to post a comment on PRs with a link to the rendered pkgdown site.
* Add inner pipeline responsible for running the model fitting process
Expand Down Expand Up @@ -33,4 +36,3 @@
* Fix bugs in parameter reading from local test run
* Fix bugs in parameter reading from local test run
* Add "US" as an option in `state_abb`
* Creating a Config class to make syncing configuration differences easier.
130 changes: 44 additions & 86 deletions R/config.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ character_or_null <- S7::new_union(S7::class_character, NULL)
#' @param path A string specifying the path to a CSV file containing exclusion
#' data. It should include at least the columns: `reference_date`,
#' `report_date`, ' `state_abb`, `disease`.
#' @param blob_storage_container Optional. The name of the blob storage
#' container to get it from. If NULL, will look locally.
#' @export
Exclusions <- S7::new_class(
"Exclusions",
properties = list(
path = character_or_null
path = character_or_null,
blob_storage_container = character_or_null
)
)

Expand All @@ -26,7 +29,7 @@ Exclusions <- S7::new_class(
GenerationInterval <- S7::new_class(
"GenerationInterval",
properties = list(
path = S7::class_character,
path = character_or_null,
blob_storage_container = character_or_null
)
)
Expand All @@ -42,7 +45,7 @@ GenerationInterval <- S7::new_class(
DelayInterval <- S7::new_class(
"DelayInterval",
properties = list(
path = S7::class_character,
path = character_or_null,
blob_storage_container = character_or_null
)
)
Expand All @@ -58,7 +61,7 @@ DelayInterval <- S7::new_class(
RightTruncation <- S7::new_class(
"RightTruncation",
properties = list(
path = S7::class_character,
path = character_or_null,
blob_storage_container = character_or_null
)
)
Expand All @@ -80,51 +83,6 @@ Parameters <- S7::new_class(
)
)

#' RtPrior Class
#'
#' Represents the Rt prior parameters.
#'
#' @param mean A numeric value representing the mean of the Rt prior.
#' @param sd A numeric value representing the standard deviation of the Rt
#' prior.
#' @export
RtPrior <- S7::new_class(
"RtPrior",
properties = list(
mean = S7::class_numeric,
sd = S7::class_numeric
)
)

#' GpPrior Class
#'
#' Represents the Gaussian Process prior parameters.
#'
#' @param alpha_sd A numeric value representing the standard deviation of the
#' alpha parameter in the GP prior.
#' @export
GpPrior <- S7::new_class(
"GpPrior",
properties = list(
alpha_sd = S7::class_numeric
)
)

#' Priors Class
#'
#' Holds all prior-related configurations for the pipeline.
#'
#' @param rt An instance of `RtPrior` class.
#' @param gp An instance of `GpPrior` class.
#' @export
Priors <- S7::new_class(
"Priors",
properties = list(
rt = S7::S7_class(RtPrior()),
gp = S7::S7_class(GpPrior())
)
)

#' Data Class
#'
#' Represents the data-related configurations.
Expand All @@ -148,30 +106,6 @@ Data <- S7::new_class(
)
)

#' SamplerOpts Class
#'
#' Represents the sampler options for the pipeline.
#'
#' @param cores An integer specifying the number of CPU cores to use.
#' @param chains An integer specifying the number of Markov chains.
#' @param iter_warmup An integer specifying the number of warmup iterations.
#' @param iter_sampling An integer specifying the number of sampling iterations.
#' @param adapt_delta A numeric value for the target acceptance probability.
#' @param max_treedepth An integer specifying the maximum tree depth for the
#' sampler.
#' @export
SamplerOpts <- S7::new_class(
"SamplerOpts",
properties = list(
cores = S7::class_integer,
chains = S7::class_integer,
iter_warmup = S7::class_integer,
iter_sampling = S7::class_integer,
adapt_delta = S7::class_numeric,
max_treedepth = S7::class_integer
)
)

#' Config Class
#'
#' Represents the complete configuration for the pipeline.
Expand All @@ -188,11 +122,13 @@ SamplerOpts <- S7::new_class(
#' @param data An instance of `Data` class containing data configurations.
#' @param seed An integer for setting the random seed.
#' @param horizon An integer specifying the forecasting horizon.
#' @param priors An instance of `Priors` class containing prior configurations.
#' @param priors A list of lists. The first level should contain the key `rt`
#' with elements `mean` and `sd` and the key `gp` with element `alpha_sd`.
#' @param parameters An instance of `Parameters` class containing parameter
#' configurations.
#' @param sampler_opts An instance of `SamplerOpts` class containing sampler
#' options.
#' @param sampler_opts A list. The Stan sampler options to be passed through
#' EpiNow2. It has required keys: `cores`, `chains`, `iter_warmup`,
#' `iter_sampling`, `max_treedepth`, and `adapt_delta`.
#' @param exclusions An instance of `Exclusions` class containing exclusion
#' criteria.
#' @param config_version A numeric value specifying the configuration version.
Expand All @@ -201,6 +137,8 @@ SamplerOpts <- S7::new_class(
#' @param model A string specifying the model to be used.
#' @param report_date A string representing the report date. Formatted as
#' "YYYY-MM-DD".
#' @param as_of_date A string representing the as-of date. Formatted as
#' "YYYY-MM-DD".
#' @export
Config <- S7::new_class(
"Config",
Expand All @@ -210,6 +148,7 @@ Config <- S7::new_class(
min_reference_date = S7::class_character,
max_reference_date = S7::class_character,
report_date = S7::class_character,
as_of_date = S7::class_character,
disease = S7::class_character,
geo_value = S7::class_character,
geo_type = S7::class_character,
Expand All @@ -219,9 +158,30 @@ Config <- S7::new_class(
config_version = S7::class_character,
quantile_width = S7::new_property(S7::class_vector, default = c(0.5, 0.95)),
data = S7::S7_class(Data()),
priors = S7::S7_class(Priors()),
# Using a list instead of an S7 object, because EpiNow2 expects a list, and
# because it reduces changes to the pipeline code.
# Adding a default that shows the required keys, with the expected types as
# the values. Should fail loudly if the default values are used, but will
# be useful to the user to see what is expected.
priors = S7::new_property(S7::class_list, default = list(
rt = list(mean = S7::class_numeric, sd = S7::class_numeric),
gp = list(alpha_sd = S7::class_numeric)
)),
parameters = S7::S7_class(Parameters()),
sampler_opts = S7::S7_class(SamplerOpts()),
# Using a list instead of an S7 object, because stan expects a list, and
# because it reduces changes to the pipeline code.
# Adding a default that shows the required keys, with the expected types as
# the values. Should fail loudly if the default values are used, but will
# be useful to the user to see what is expected.
# Using a list here also reduces changes to the pipeline code.
sampler_opts = S7::new_property(S7::class_list, default = list(
cores = S7::class_integer,
chains = S7::class_integer,
iter_warmup = S7::class_integer,
iter_sampling = S7::class_integer,
max_treedepth = S7::class_integer,
adapt_delta = S7::class_numeric
)),
exclusions = S7::S7_class(Exclusions())
)
)
Expand All @@ -242,12 +202,8 @@ read_json_into_config <- function(config_path) {
# create a more automated way to do this, we can remove this.
str2class <- list(
data = Data,
priors = Priors,
parameters = Parameters,
sampler_opts = SamplerOpts,
exclusions = Exclusions,
rt = RtPrior,
gp = GpPrior,
generation_interval = GenerationInterval,
delay_interval = DelayInterval,
right_truncation = RightTruncation
Expand All @@ -260,9 +216,9 @@ read_json_into_config <- function(config_path) {
missing_properties <- setdiff(S7::prop_names(Config()), names(raw_input))
# Error out if missing any fields
if (length(missing_properties) > 0) {
cli::cli_abort(c(
"The following properties are missing from the config file:",
"{.var missing_properties}"
cli::cli_alert_info(c(
"The following expected properties were not in the config file:",
"{.var {missing_properties}}"
))
}

Expand All @@ -279,7 +235,9 @@ read_json_into_config <- function(config_path) {
raw_data[[prop_name]], str2class[[prop_name]]
)
} else if (!(prop_name %in% S7::prop_names(class_to_fill()))) {
cli::cli_warn("No Config field matching {.var {prop_name}}. Not using.")
cli::cli_alert_info(
"No Config field matching {.var {prop_name}}. Not using."
)
} else {
# Else set it directly
S7::prop(config, prop_name) <- raw_data[[prop_name]]
Expand Down
26 changes: 18 additions & 8 deletions R/pipeline.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,24 @@
orchestrate_pipeline <- function(config_path,
blob_storage_container = NULL,
output_dir = "/") {
# TODO: Add config reader here
config <- jsonlite::read_json(config_path,
simplifyVector = TRUE
config <- rlang::try_fetch(
read_json_into_config(config_path),
error = function(con) {
cli::cli_warn("Bad config file",
parent = con,
class = "Bad_config"
)
FALSE
}
)
if (typeof(config) == "logical") {
return(invisible(FALSE))
}

write_output_dir_structure(
output_dir = output_dir,
job_id = config[["job_id"]],
task_id = config[["task_id"]]
job_id = config@job_id,
task_id = config@task_id
)

# Set up logs
Expand All @@ -104,8 +113,8 @@ orchestrate_pipeline <- function(config_path,
)
on.exit(sink(file = NULL))
cli::cli_alert_info("Starting run at {Sys.time()}")
cli::cli_alert_info("Using job id {.field {config[['job_id']]}}")
cli::cli_alert_info("Using task id {.field {config[['task_id']]}}")
cli::cli_alert_info("Using job id {.field {config@job_id}}")
cli::cli_alert_info("Using task id {.field {config@task_id}}")

# Errors within `execute_model_logic()` are downgraded to warnings so
# they can be logged and stored in Blob. If there is an error,
Expand Down Expand Up @@ -153,7 +162,8 @@ execute_model_logic <- function(config, output_dir) {
min_reference_date = config@min_reference_date
)

if (!rlang::is_null(config@exclusions@path)) {
# rlang::is_empty() checks for empty and NULL values
if (!rlang::is_empty(config@exclusions@path)) {
exclusions_df <- read_exclusions(config@exclusions@path)
cases_df <- apply_exclusions(cases_df, exclusions_df)
} else {
Expand Down
Loading

0 comments on commit db75206

Please sign in to comment.