Skip to content

Commit

Permalink
make cases prediction optional, switch markdown to use function, impr…
Browse files Browse the repository at this point in the history
…ove guidance
  • Loading branch information
simon-smart88 committed Nov 13, 2024
1 parent c83b400 commit f7cd64f
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 38 deletions.
25 changes: 19 additions & 6 deletions R/pred_pred_f.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @param fit disag_model. Object returned by disag_model function that
#' contains all the necessary objects for generating predictions.
#' @param aggregation SpatRaster. The aggregation raster
#' @param cases logical. Whether to predictions of cases
#' @param predict_iid logical. Whether to generate predictions including the iid effect
#' @param uncertain logical. Whether or not to generate upper and lower credible
#' intervals
Expand All @@ -20,31 +21,43 @@
#' @author Simon Smart <simon.smart@@cantab.net>
#' @export

pred_pred <- function(fit, aggregation, predict_iid, uncertain = FALSE, N = NULL, CI = NULL, async = FALSE){
pred_pred <- function(fit, aggregation, cases, predict_iid, uncertain = FALSE, N = NULL, CI = NULL, async = FALSE){

if (async){
fit$data$covariate_rasters <- terra::unwrap(fit$data$covariate_rasters)
aggregation <- terra::unwrap(aggregation)
}

prediction <- disaggregation::predict_model(fit, predict_iid = predict_iid)
prediction$cases <- prediction$prediction * aggregation

if (cases){
prediction$cases <- prediction$prediction * aggregation
}

if (!is.null(prediction$field)){
terra::crs(prediction$field) <- terra::crs(fit$data$covariate_rasters[[1]])
prediction$field <- terra::mask(prediction$field, fit$data$covariate_rasters[[1]])
}

if (!is.null(prediction$iid)){
prediction$iid <- terra::mask(prediction$iid, fit$data$covariate_rasters[[1]])
}

if (uncertain){
prediction$uncertainty <- disaggregation::predict_uncertainty(fit, predict_iid = predict_iid, N = N, CI = CI)
}

names(prediction)[which(names(prediction) == "prediction")] <- "prediction (rate)"
names(prediction)[which(names(prediction) == "cases")] <- "prediction (cases)"
if (cases){
names(prediction)[which(names(prediction) == "cases")] <- "prediction (cases)"
}


if (async){
prediction$`prediction (rate)` <- terra::wrap(prediction$`prediction (rate)`)
prediction$`prediction (cases)` <- terra::wrap(prediction$`prediction (cases)`)
if (cases){
prediction$`prediction (cases)` <- terra::wrap(prediction$`prediction (cases)`)
}
prediction$covariates <- terra::wrap(prediction$covariates)

if (!is.null(prediction$field)){
Expand All @@ -54,8 +67,8 @@ pred_pred <- function(fit, aggregation, predict_iid, uncertain = FALSE, N = NULL
prediction$iid <- terra::wrap(prediction$iid)
}
if (uncertain){
prediction$uncertainty$predictions_ci$`lower CI` <- terra::wrap(prediction$uncertainty$predictions_ci$`lower CI`)
prediction$uncertainty$predictions_ci$`upper CI` <- terra::wrap(prediction$uncertainty$predictions_ci$`upper CI`)
prediction$uncertainty_lower <- terra::wrap(prediction$uncertainty$predictions_ci$`lower CI`)
prediction$uncertainty_upper <- terra::wrap(prediction$uncertainty$predictions_ci$`upper CI`)
}
}

Expand Down
4 changes: 2 additions & 2 deletions inst/shiny/modules/core_load.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ core_load_module_server <- function(id, common, modules, map, COMPONENT_MODULES,
common$transfer$prediction <- unwrap_terra(common$transfer$prediction)
common$transfer$field <- unwrap_terra(common$transfer$field)
common$transfer$covariates <- unwrap_terra(common$transfer$covariates)
common$pred$uncertainty$predictions_ci$`lower CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`lower CI`)
common$pred$uncertainty$predictions_ci$`upper CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`upper CI`)
common$pred$uncertainty_lower <- unwrap_terra(common$pred$uncertainty_lower)
common$pred$uncertainty_upper <- unwrap_terra(common$pred$uncertainty_upper)

#restore map and results for used modules
for (used_module in names(common$meta)){
Expand Down
8 changes: 4 additions & 4 deletions inst/shiny/modules/core_save.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ core_save_module_server <- function(id, common, modules, COMPONENTS, main_input)
common$transfer$prediction <- wrap_terra(common$transfer$prediction)
common$transfer$field <- wrap_terra(common$transfer$field)
common$transfer$covariates <- wrap_terra(common$transfer$covariates)
common$pred$uncertainty$predictions_ci$`lower CI` <- wrap_terra(common$pred$uncertainty$predictions_ci$`lower CI`)
common$pred$uncertainty$predictions_ci$`upper CI` <- wrap_terra(common$pred$uncertainty$predictions_ci$`upper CI`)
common$pred$uncertainty_lower <- wrap_terra(common$pred$uncertainty_lower)
common$pred$uncertainty_upper <- wrap_terra(common$pred$uncertainty_upper)

#save file
saveRDS(common, file)
Expand All @@ -92,8 +92,8 @@ core_save_module_server <- function(id, common, modules, COMPONENTS, main_input)
common$transfer$prediction <- unwrap_terra(common$transfer$prediction)
common$transfer$field <- unwrap_terra(common$transfer$field)
common$transfer$covariates <- unwrap_terra(common$transfer$covariates)
common$pred$uncertainty$predictions_ci$`lower CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`lower CI`)
common$pred$uncertainty$predictions_ci$`upper CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`upper CI`)
common$pred$uncertainty_lower <- unwrap_terra(common$pred$uncertainty_lower)
common$pred$uncertainty_upper <- unwrap_terra(common$pred$uncertainty_upper)

close_loading_modal()
}
Expand Down
21 changes: 14 additions & 7 deletions inst/shiny/modules/pred_pred.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pred_pred_module_ui <- function(id) {
ns <- shiny::NS(id)
tagList(
uiOutput(ns("iid_out")),
shinyWidgets::materialSwitch(ns("cases"), "Include cases?", FALSE, status = "success"),
shinyWidgets::materialSwitch(ns("uncertain"), "Include uncertainty?", FALSE, status = "success"),
conditionalPanel("input.uncertain === true", ns = ns,
tags$label("Uncertainty parameters"),
Expand Down Expand Up @@ -67,12 +68,15 @@ pred_pred_module_server <- function(id, common, parent_session, map) {
if (!input$uncertain){
common$tasks$pred_pred$invoke(fit = common$fit,
aggregation = common[[aggregation]],
cases = input$cases,
predict_iid = predict_iid,
async = TRUE)
} else {
common$tasks$pred_pred$invoke(fit = common$fit,
aggregation = common[[aggregation]],
cases = input$cases,
predict_iid = predict_iid,
uncertain = input$uncertain,
N = input$uncertain_n,
CI = input$uncertain_ci,
async = TRUE)
Expand All @@ -84,6 +88,8 @@ pred_pred_module_server <- function(id, common, parent_session, map) {

# METADATA ####
common$meta$pred_pred$used <- TRUE
common$meta$pred_pred$cases <- input$cases

if (input$uncertain){
common$meta$pred_pred$uncertain <- input$uncertain
common$meta$pred_pred$uncertain_n <- input$uncertain_n
Expand All @@ -100,16 +106,15 @@ pred_pred_module_server <- function(id, common, parent_session, map) {
results <- observe({
# LOAD INTO COMMON ####

common$pred<- common$tasks$pred_pred$result()
common$pred <- common$tasks$pred_pred$result()
results$suspend()

common$pred$field <- unwrap_terra(common$pred$field)
common$pred$`prediction (rate)` <- unwrap_terra(common$pred$`prediction (rate)`)
common$pred$`prediction (cases)` <- unwrap_terra(common$pred$`prediction (cases)`)
common$pred$covariates <- unwrap_terra(common$pred$covariates)
common$pred$iid <- unwrap_terra(common$pred$iid)
common$pred$uncertainty$predictions_ci$`lower CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`lower CI`)
common$pred$uncertainty$predictions_ci$`upper CI` <- unwrap_terra(common$pred$uncertainty$predictions_ci$`upper CI`)
common$pred$uncertainty_lower <- unwrap_terra(common$pred$uncertainty_lower)
common$pred$uncertainty_upper <- unwrap_terra(common$pred$uncertainty_upper)

common$logger |> writeLog(type = "complete", "Model predictions are available")
# TRIGGER
Expand Down Expand Up @@ -152,6 +157,7 @@ pred_pred_module_server <- function(id, common, parent_session, map) {
save = function() {list(
### Manual save start
### Manual save end
cases = input$cases,
uncertain_n = input$uncertain_n,
uncertain_ci = input$uncertain_ci,
uncertain = input$uncertain,
Expand All @@ -164,6 +170,7 @@ pred_pred_module_server <- function(id, common, parent_session, map) {
updateNumericInput(session, "uncertain_ci", value = state$uncertain_ci)
shinyWidgets::updateMaterialSwitch(session, "uncertain", value = state$uncertain)
shinyWidgets::updateMaterialSwitch(session, "iid", value = state$iid)
shinyWidgets::updateMaterialSwitch(session, "cases", value = state$cases)
}
))
})
Expand All @@ -176,8 +183,8 @@ pred_pred_module_map <- function(map, common) {
}
}
if (!is.null(common$meta$pred_pred$uncertain)){
raster_map(map, common, common$pred$uncertainty$predictions_ci$`lower CI`, "Lower credible interval")
raster_map(map, common, common$pred$uncertainty$predictions_ci$`upper CI`, "Upper credible interval")
raster_map(map, common, common$pred$uncertainty_lower, "Lower credible interval")
raster_map(map, common, common$pred$uncertainty_upper, "Upper credible interval")
}
}

Expand All @@ -186,7 +193,7 @@ pred_pred_module_rmd <- function(common) {
list(
pred_knit = !is.null(common$meta$pred_pred$used),
pred_iid = common$meta$pred_pred$iid,
pred_uncertain_knit = !is.null(common$meta$pred_pred$uncertain),
pred_cases = common$meta$pred_pred$cases,
pred_uncertain = common$meta$pred_pred$uncertain,
pred_uncertain_n = common$meta$pred_pred$uncertain_n,
pred_uncertain_ci = common$meta$pred_pred$uncertain_ci
Expand Down
35 changes: 18 additions & 17 deletions inst/shiny/modules/pred_pred.Rmd
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
```{asis, echo = {{pred_knit}}, eval = {{pred_knit}}, include = {{pred_knit}}}
Generate predictions from the model
Generate predictions from the model and plot the generated rasters
```

```{r, echo = {{pred_knit}}, include = {{pred_knit}}}
prediction <- disaggregation::predict_model(fitted_model, predict_iid = {{pred_iid}})
plot(prediction$prediction)
prediction$cases <- prediction$prediction * aggregation_prepared
plot(prediction$cases)
prediction <- disagapp::pred_pred(fitted_model,
aggregation_prepared,
cases = {{pred_cases}},
predict_iid = {{pred_iid}},
uncertain = {{pred_uncertain}},
N = {{pred_uncertain_n}},
CI = {{pred_uncertain_ci}})
plot(prediction$`prediction (rate)`)
if (!is.null(prediction$`prediction (cases)`)){
plot(prediction$`prediction (cases)`)
}
if (!is.null(prediction$field)){
terra::crs(prediction$field) <- terra::crs(prepared_data$covariate_rasters[[1]])
prediction$field <- terra::mask(prediction$field, prepared_data$covariate_rasters[[1]])
plot(prediction$field)
}
if (!is.null(prediction$iid)){
plot(prediction$iid)
}
```

```{r, echo = {{pred_uncertain_knit}}, include = {{pred_uncertain_knit}}}
uncertainty <- disaggregation::predict_uncertainty(fitted_model, predict_iid = {{pred_iid}}, N = {{pred_uncertain_n}}, CI = {{pred_uncertain_ci}})
if (!is.null(prediction$uncertainty)){
plot(uncertainty$predictions_ci$`lower CI`)
plot(uncertainty$predictions_ci$`upper CI`)
}
plot(uncertainty$predictions_ci$`lower CI`)
plot(uncertainty$predictions_ci$`upper CI`)
```

4 changes: 2 additions & 2 deletions inst/shiny/modules/pred_pred.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

**BACKGROUND**

This module generates predictions from the fitted model for the area of interest.
This module generates predictions from the fitted model for the area of interest. You can choose to convert the rate predictions to cases and to generate uncertainty predictions by toggling the switches. Once you have generated the predictions you can download a `.zip` file of them.

**IMPLEMENTATION**

Click "Produce model predictions" and the output will be visible on the map.
This module uses `disaggregation::predict_model()` to generate the rate predictions and `disaggregation::predict_uncertainty()` to generate the uncertainty predictions. The cases are generated by multiplying the rate prediction by the aggregation raster. If you have included an IID effect in the model then you can choose whether or not to include this term when generating the predictions. If you choose to generate uncertainty predictions you can select the credible interval of the uncertainty and the number of iterations used to generate the uncertainty predictions. Once you have made your selection about which predictions to generate, click "Produce model predictions" and the output will be visible on the map.
3 changes: 3 additions & 0 deletions tests/testthat/test-complete_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ test_that("{shinytest2} recording: e2e_complete_analysis", {

app$set_inputs(tabs = "pred")
app$set_inputs(predSel = "pred_pred")
app$set_inputs("pred_pred-cases" = TRUE)
app$set_inputs("pred_pred-iid" = TRUE)
app$set_inputs("pred_pred-uncertain" = TRUE)
app$click(selector = "#pred_pred-run")
app$wait_for_value(input = "pred_pred-complete")
app$set_inputs(main = "Save")
Expand Down

0 comments on commit f7cd64f

Please sign in to comment.