From 5c3cee2f02a79bd3f37ad5e39bd2c5946cdc467a Mon Sep 17 00:00:00 2001 From: Nate McIntosh Date: Fri, 29 Nov 2024 15:09:14 -0600 Subject: [PATCH] now error out on invalid config files --- R/config.R | 21 ++++- R/pipeline.R | 2 +- man/Config.Rd | 80 +++++++++---------- man/read_json_into_config.Rd | 8 +- .../data/sample_config_no_exclusion.json | 3 +- .../data/sample_config_with_exclusion.json | 3 +- tests/testthat/data/v_bad_config.json | 4 + tests/testthat/test-pipeline.R | 25 +++++- 8 files changed, 95 insertions(+), 51 deletions(-) create mode 100644 tests/testthat/data/v_bad_config.json diff --git a/R/config.R b/R/config.R index 7cf80a4..b0363cf 100644 --- a/R/config.R +++ b/R/config.R @@ -190,10 +190,15 @@ Config <- S7::new_class( #' #' @param config_path A string specifying the path to the JSON configuration #' file. +#' @param optional_fields A list of strings specifying the optional fields in +#' the JSON file. If a field is not present in the JSON file, and is marked as +#' optional, it will be set to either the empty type (e.g. `chr(0)`), or NULL. +#' If a field is not present in the JSON file, and is not marked as optional, an +#' error will be thrown. #' @return An instance of the `Config` class populated with the data from the #' JSON file. #' @export -read_json_into_config <- function(config_path) { +read_json_into_config <- function(config_path, optional_fields) { # First, our hard coded, flattened, map from strings to Classes. If any new # subclasses are added above, they will also need to be added here. If we # create a more automated way to do this, we can remove this. @@ -211,11 +216,19 @@ read_json_into_config <- function(config_path) { # Check what top level properties were not in the raw input missing_properties <- setdiff(S7::prop_names(Config()), names(raw_input)) + # Remove any optional fields from the missing properties, give info message + # about what is being given a default arg. + not_need_but_missing <- intersect(optional_fields, missing_properties) + if (length(not_need_but_missing) > 0) { + cli::cli_alert_info( + "Optional field{?s} not in config file: {.var {not_need_but_missing}}" + ) + } + missing_properties <- setdiff(missing_properties, optional_fields) # Error out if missing any fields if (length(missing_properties) > 0) { - cli::cli_alert_info(c( - "The following expected propert{?y/ies} were not in the config file:", - "{.var {missing_properties}}" + cli::cli_abort(c( + "Propert{?y/ies} not in the config file: {.var {missing_properties}}" )) } diff --git a/R/pipeline.R b/R/pipeline.R index cbb1f81..436991c 100644 --- a/R/pipeline.R +++ b/R/pipeline.R @@ -72,7 +72,7 @@ orchestrate_pipeline <- function(config_path, blob_storage_container = NULL, output_dir = "/") { config <- rlang::try_fetch( - read_json_into_config(config_path), + read_json_into_config(config_path, c("exclusions")), error = function(con) { cli::cli_warn("Bad config file", parent = con, diff --git a/man/Config.Rd b/man/Config.Rd index 1809dbe..bc23f6e 100644 --- a/man/Config.Rd +++ b/man/Config.Rd @@ -21,136 +21,136 @@ Config( data = Data(), priors = list(rt = list(mean = structure(list(classes = list(structure(list(class = "integer", constructor_name = "integer", constructor = function (.data = integer(0)) - - .data, validator = function (object) + + .data, validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) - + } }), class = "S7_base_class"), structure(list(class = "double", constructor_name = - "double", constructor = function (.data = numeric(0)) + "double", constructor = function (.data = numeric(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"))), class = "S7_union"), sd = structure(list(classes = list(structure(list(class = "integer", constructor_name = "integer", constructor = - function (.data = integer(0)) - .data, validator = function (object) + function (.data = integer(0)) + .data, validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), structure(list(class = "double", constructor_name = - "double", constructor = function (.data = numeric(0)) + "double", constructor = function (.data = numeric(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"))), class = "S7_union")), gp = list(alpha_sd = structure(list(classes = list(structure(list(class = "integer", constructor_name = - "integer", constructor = function (.data = integer(0)) + "integer", constructor = function (.data = integer(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), structure(list(class = "double", constructor_name = - "double", constructor = function (.data = numeric(0)) + "double", constructor = function (.data = numeric(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"))), class = "S7_union"))), parameters = Parameters(), sampler_opts = list(cores = structure(list(class = "integer", constructor_name = - "integer", constructor = function (.data = integer(0)) + "integer", constructor = function (.data = integer(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), chains = structure(list(class = "integer", - constructor_name = "integer", constructor = function (.data = integer(0)) + constructor_name = "integer", constructor = function (.data = integer(0)) .data, - validator = function (object) + validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), iter_warmup = structure(list(class = "integer", - constructor_name = "integer", constructor = function (.data = integer(0)) + constructor_name = "integer", constructor = function (.data = integer(0)) .data, - validator = function (object) + validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), iter_sampling = structure(list(class = "integer", - constructor_name = "integer", constructor = function (.data = integer(0)) + constructor_name = "integer", constructor = function (.data = integer(0)) .data, - validator = function (object) + validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), max_treedepth = structure(list(class = "integer", - constructor_name = "integer", constructor = function (.data = integer(0)) + constructor_name = "integer", constructor = function (.data = integer(0)) .data, - validator = function (object) + validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), adapt_delta = structure(list(classes = list(structure(list(class = "integer", constructor_name = "integer", constructor = - function (.data = integer(0)) - .data, validator = function (object) + function (.data = integer(0)) + .data, validator = function (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } }), class = "S7_base_class"), structure(list(class = "double", constructor_name = - "double", constructor = function (.data = numeric(0)) + "double", constructor = function (.data = numeric(0)) .data, validator = function - (object) + (object) { if (base_class(object) != name) { - + sprintf("Underlying data must be <\%s> not <\%s>", name, base_class(object)) } diff --git a/man/read_json_into_config.Rd b/man/read_json_into_config.Rd index a01d435..f1e085e 100644 --- a/man/read_json_into_config.Rd +++ b/man/read_json_into_config.Rd @@ -4,11 +4,17 @@ \alias{read_json_into_config} \title{Read JSON Configuration into Config Object} \usage{ -read_json_into_config(config_path) +read_json_into_config(config_path, optional_fields) } \arguments{ \item{config_path}{A string specifying the path to the JSON configuration file.} + +\item{optional_fields}{A list of strings specifying the optional fields in +the JSON file. If a field is not present in the JSON file, and is marked as +optional, it will be set to either the empty type (e.g. \code{chr(0)}), or NULL. +If a field is not present in the JSON file, and is not marked as optional, an +error will be thrown.} } \value{ An instance of the \code{Config} class populated with the data from the diff --git a/tests/testthat/data/sample_config_no_exclusion.json b/tests/testthat/data/sample_config_no_exclusion.json index 15725a3..2d81b7e 100644 --- a/tests/testthat/data/sample_config_no_exclusion.json +++ b/tests/testthat/data/sample_config_no_exclusion.json @@ -46,5 +46,6 @@ "iter_sampling": 50, "adapt_delta": 0.99, "max_treedepth": 12 - } + }, + "config_version": "0.1.0" } diff --git a/tests/testthat/data/sample_config_with_exclusion.json b/tests/testthat/data/sample_config_with_exclusion.json index c92e61b..5e8e61b 100644 --- a/tests/testthat/data/sample_config_with_exclusion.json +++ b/tests/testthat/data/sample_config_with_exclusion.json @@ -50,5 +50,6 @@ "iter_sampling": 50, "adapt_delta": 0.99, "max_treedepth": 12 - } + }, + "config_version": "0.1.0" } diff --git a/tests/testthat/data/v_bad_config.json b/tests/testthat/data/v_bad_config.json new file mode 100644 index 0000000..69931c1 --- /dev/null +++ b/tests/testthat/data/v_bad_config.json @@ -0,0 +1,4 @@ +{ + "job_id": "6183da58-89bc-455f-8562-4f607257a876", + "task_id": "bc0c3eb3-7158-4631-a2a9-86b97357f97e" +} \ No newline at end of file diff --git a/tests/testthat/test-pipeline.R b/tests/testthat/test-pipeline.R index 72b2338..3085d4c 100644 --- a/tests/testthat/test-pipeline.R +++ b/tests/testthat/test-pipeline.R @@ -14,7 +14,7 @@ test_that("Bad config throws warning and returns failure", { blob_storage_container = blob_storage_container, output_dir = output_dir ), - class = "Run_failed" + class = "Bad_config" ) expect_false(pipeline_success) }) @@ -71,7 +71,7 @@ test_that("Pipeline run produces expected outputs with exclusions", { test_that("Process pipeline produces expected outputs and returns success", { # Arrange config_path <- test_path("data", "sample_config_with_exclusion.json") - config <- read_json_into_config(config_path) + config <- read_json_into_config(config_path, c("exclusions")) # Read from locally output_dir <- "pipeline_test" on.exit(unlink(output_dir, recursive = TRUE)) @@ -95,7 +95,7 @@ test_that("Process pipeline produces expected outputs and returns success", { test_that("Runs on config from generator as of 2024-11-26", { # Arrange config_path <- test_path("data", "CA_COVID-19.json") - config <- read_json_into_config(config_path) + config <- read_json_into_config(config_path, c("exclusions")) # Read from locally output_dir <- "pipeline_test" on.exit(unlink(output_dir, recursive = TRUE)) @@ -115,3 +115,22 @@ test_that("Runs on config from generator as of 2024-11-26", { config@task_id ) }) + +test_that("Warning and exit for bad config file", { + # Arrange + config_path <- test_path("data", "v_bad_config.json") + # Read from locally + output_dir <- "pipeline_test" + on.exit(unlink(output_dir, recursive = TRUE)) + + # Act + expect_warning( + pipeline_success <- orchestrate_pipeline( + config_path = config_path, + blob_storage_container = NULL, + output_dir = output_dir + ), + class = "Bad_config" + ) + expect_false(pipeline_success) +})