Skip to content

Commit

Permalink
Update output schema
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Oct 1, 2024
1 parent 8161d43 commit 54338ed
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 31 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ repos:
entry: Cannot commit .Rhistory, .RData, .Rds or .rds.
language: fail
files: '\.(Rhistory|RData|Rds|rds)$'
exclude: '^tests/testthat/data/.*\.rds$'
# `exclude: <regex>` to allow committing specific files
#####
# Python
Expand Down
16 changes: 13 additions & 3 deletions R/extract_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' @param data A data frame containing the input data used in the model fit.
#' @param job_id A unique identifier for the job or task being processed.
#' @param task_id A unique identifier for the task being performed.
#'
#' @param disease,geo_value,model Metadata for downstream processing.
#'
#' @return A \code{data.frame} containing the extracted diagnostic metrics. The
#' data frame includes the following columns:
Expand All @@ -23,6 +23,7 @@
#' \item \code{disease}: The disease/pathogen being analyzed.
#' \item \code{job_id}: The unique identifier for the job.
#' \item \code{task_id}: The unique identifier for the task.
#' \item \code{disease,geo_value,model}: Metadata for downstream processing.
#' }
#'
#' @details
Expand All @@ -43,7 +44,13 @@
#' any diagnostic thresholds are exceeded.
#' }
#' @export
extract_diagnostics <- function(fit, data, job_id, task_id) {
extract_diagnostics <- function(fit,
data,
job_id,
task_id,
disease,
geo_value,
model) {
low_case_count <- low_case_count_diagnostic(data)

epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit,
Expand Down Expand Up @@ -94,7 +101,10 @@ extract_diagnostics <- function(fit, data, job_id, task_id) {
diagnostic = diagnostic_names,
value = diagnostic_values,
job_id = job_id,
task_id = task_id
task_id = task_id,
disease = disease,
geo_value = geo_value,
model = model
)
}

Expand Down
35 changes: 28 additions & 7 deletions R/write_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ write_model_outputs <- function(
job_id,
"tasks",
task_id,
"model.RDS"
"model.rds"
)
saveRDS(fit, model_path)
cli::cli_alert_success("Wrote model to {.path {model_path}}")
Expand Down Expand Up @@ -193,7 +193,12 @@ extract_draws_from_fit <- function(fit) {
#' @return A data.table with merged posterior draws and standardized parameter
#' names.
#' @noRd
post_process_and_merge <- function(draws, fact_table) {
post_process_and_merge <- function(
draws,
fact_table,
geo_value,
model,
disease) {
# Step 1: Left join the date-time-parameter map onto the Stan draws
merged_dt <- merge(
draws,
Expand All @@ -218,13 +223,18 @@ post_process_and_merge <- function(draws, fact_table) {
".point", ".interval", "date", ".iteration"
),
new = c(
"_draw", "_chain", "_variable", "_value", "_lower", "_upper", "_width",
"_draw", "_chain", "_variable", "value", "_lower", "_upper", "_width",
"_point", "_interval", "reference_date", "_iteration"
),
# If using summaries, skip draws-specific names
skip_absent = TRUE
)

# Metadata for downstream querying without path parsing or joins
data.table::set(merged_dt, j = "geo_value", value = factor(geo_value))
data.table::set(merged_dt, j = "model", value = factor(model))
data.table::set(merged_dt, j = "disease", value = factor(disease))

return(merged_dt)
}

Expand All @@ -236,6 +246,7 @@ post_process_and_merge <- function(draws, fact_table) {
#' returned in `{tidybayes}` format.
#'
#' @param fit An EpiNow2 fit object with posterior estimates.
#' @param disease,geo_value,model Metadata for downstream processing.
#'
#' @return A data.table of posterior draws or quantiles, merged and processed.
#'
Expand All @@ -244,17 +255,21 @@ NULL

#' @rdname sample_processing_functions
#' @export
process_samples <- function(fit) {
process_samples <- function(fit, geo_value, model, disease) {
draws_list <- extract_draws_from_fit(fit)
raw_processed_output <- post_process_and_merge(
draws_list$stan_draws, draws_list$fact_table
draws_list$stan_draws,
draws_list$fact_table,
geo_value,
model,
disease
)
return(raw_processed_output)
}

#' @rdname sample_processing_functions
#' @export
process_quantiles <- function(fit) {
process_quantiles <- function(fit, geo_value, model, disease) {
# Step 1: Extract the draws
draws_list <- extract_draws_from_fit(fit)

Expand All @@ -268,7 +283,13 @@ process_quantiles <- function(fit) {
data.table::as.data.table()

# Step 3: Post-process summarized draws
post_process_and_merge(summarized_draws, draws_list$fact_table)
post_process_and_merge(
summarized_draws,
draws_list$fact_table,
geo_value,
model,
disease
)
}

write_parquet <- function(data, path) {
Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,20 @@ This package implements functions for:

## Output format

Outputs are stored in a s

```bash
output/
├── job_<job_id>/
├── <job_id>/
│ ├── raw_samples/
│ │ ├── raw_samples_task_<task_id>.parquet
│ ├── summarized_quantiles/
│ │ ├── summarized_quantiles_task_<task_id>.parquet
│ ├── tasks/
│ │ ├── task_<task_id>/
│ │ ├── <task_id>/
│ │ │ ├── model.stan
│ │ │ ├── metadata.json
│ │ │ ├── task.log
│ │ │ ├── logs.txt
│ │ │ └── error.log
│ ├── job_metadata.json
```
Expand Down
5 changes: 4 additions & 1 deletion man/extract_diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/sample_processing_functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

File renamed without changes.
3 changes: 2 additions & 1 deletion tests/testthat/helper-write_parameter_file.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ write_sample_parameters_file <- function(value,
parameter,
start_date,
end_date) {
Sys.sleep(0.05)
df <- data.frame(
start_date = as.Date(start_date),
geo_value = state,
Expand All @@ -16,6 +17,7 @@ write_sample_parameters_file <- function(value,
)

con <- DBI::dbConnect(duckdb::duckdb())
on.exit(DBI::dbDisconnect(con))

duckdb::duckdb_register(con, "test_table", df)
# This is bad practice but `dbBind()` doesn't allow us to parameterize COPY
Expand All @@ -24,7 +26,6 @@ write_sample_parameters_file <- function(value,
# guard against a SQL injection attack.
query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'")
DBI::dbExecute(con, query)
DBI::dbDisconnect(con)

invisible(path)
}
15 changes: 13 additions & 2 deletions tests/testthat/test-extract_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test_that("Fitted model extracts diagnostics", {
params = list(data_path)
)
DBI::dbDisconnect(con)
fit_path <- test_path("data", "sample_fit.RDS")
fit_path <- test_path("data", "sample_fit.rds")
fit <- readRDS(fit_path)

# Expected diagnostics
Expand All @@ -37,9 +37,20 @@ test_that("Fitted model extracts diagnostics", {
),
job_id = rep("test", 6),
task_id = rep("test", 6),
disease = rep("test", 6),
geo_value = rep("test", 6),
model = rep("test", 6),
stringsAsFactors = FALSE
)
actual <- extract_diagnostics(fit, data, "test", "test")
actual <- extract_diagnostics(
fit,
data,
"test",
"test",
"test",
"test",
"test"
)

testthat::expect_equal(
actual,
Expand Down
40 changes: 28 additions & 12 deletions tests/testthat/test-write_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ test_that("write_model_outputs writes files and directories correctly", {
)
expect_true(file.exists(summarized_file))

# Check if model RDS file was written
# Check if model rds file was written
model_file <- file.path(
job_id,
"tasks",
task_id,
"model.RDS"
"model.rds"
)
expect_true(file.exists(model_file))

Expand Down Expand Up @@ -122,10 +122,10 @@ test_that("write_output_dir_structure generates dirs", {

test_that("process_quantiles works as expected", {
# Load the sample fit object
fit <- readRDS(test_path("data", "sample_fit.RDS"))
fit <- readRDS(test_path("data", "sample_fit.rds"))

# Run the function on the fit object
result <- process_quantiles(fit)
result <- process_quantiles(fit, "test_geo", "test_model", "test_disease")

# Test 1: Check if the result is a data.table
expect_true(
Expand All @@ -135,9 +135,18 @@ test_that("process_quantiles works as expected", {

# Test 2: Check if the necessary columns exist in the result
expected_columns <- c(
"time", "_variable", "_value",
"_lower", "_upper", "_width",
"_point", "_interval", "reference_date"
"time",
"_variable",
"value",
"_lower",
"_upper",
"_width",
"_point",
"_interval",
"reference_date",
"geo_value",
"model",
"disease"
)
expect_equal(
colnames(result), expected_columns
Expand Down Expand Up @@ -187,10 +196,10 @@ test_that("process_quantiles works as expected", {

test_that("process_samples works as expected", {
# Load the sample fit object
fit <- readRDS(test_path("data", "sample_fit.RDS"))
fit <- readRDS(test_path("data", "sample_fit.rds"))

# Run the function on the fit object
result <- process_samples(fit)
result <- process_samples(fit, "test_geo", "test_model", "test_disease")

# Test 1: Check if the result is a data.table
expect_true(
Expand All @@ -200,9 +209,16 @@ test_that("process_samples works as expected", {

# Test 2: Check if the necessary columns exist in the result
expected_columns <- c(
"time", "_variable", "_chain",
"_iteration", "_draw", "_value",
"reference_date"
"time",
"_variable",
"_chain",
"_iteration",
"_draw",
"value",
"reference_date",
"geo_value",
"model",
"disease"
)
expect_equal(
colnames(result), expected_columns
Expand Down

0 comments on commit 54338ed

Please sign in to comment.