Skip to content

Commit

Permalink
completing first version suitability model
Browse files Browse the repository at this point in the history
  • Loading branch information
gilesjohnr committed Oct 6, 2024
1 parent 31862f5 commit 8e561fe
Show file tree
Hide file tree
Showing 23 changed files with 152,466 additions and 151,940 deletions.
119 changes: 104 additions & 15 deletions R/est_suitability.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
#' }
#'
#' @note
#' The LSTM model uses lagged climate variables to predict cholera suitability. The model's predictions are saved as
#' The LSTM model uses climate variables to predict cholera suitability. The model's predictions are saved as
#' a CSV file and a plot showing the model fit (accuracy and loss) is generated.
#'
#' @seealso
Expand All @@ -77,6 +77,8 @@ est_suitability <- function(PATHS, include_lagged_covariates=FALSE) {
elevation_data <- read.csv(path, stringsAsFactors = FALSE)
d_all <- merge(d_all, elevation_data[c("iso_code", "elevation")], by="iso_code")

if (53 %in% d_all$week) stop("week index is out of bounds")

message("Adding covariates...")
covariates <- c(
"temperature_2m_mean", "temperature_2m_max", "temperature_2m_min",
Expand All @@ -86,7 +88,8 @@ est_suitability <- function(PATHS, include_lagged_covariates=FALSE) {
"dew_point_2m_mean", "dew_point_2m_min", "dew_point_2m_max",
"precipitation_sum", "snowfall_sum", "pressure_msl_mean",
"soil_moisture_0_to_10cm_mean", "et0_fao_evapotranspiration_sum",
"DMI", "ENSO3", "ENSO34", "ENSO4", "elevation"
"DMI", "ENSO3", "ENSO34", "ENSO4", "elevation"#,
#"year", "month", "week"
)

X_all <- d_all[, colnames(d_all) %in% covariates]
Expand All @@ -106,8 +109,8 @@ est_suitability <- function(PATHS, include_lagged_covariates=FALSE) {
# Step 3: Standardize the features (covariates)
if (include_lagged_covariates) {

X_all_lagged <- MOSAIC::make_lagged_data(X_all, lags=1:3)
X_lagged <- MOSAIC::make_lagged_data(X, lags=1:3)
X_all_lagged <- MOSAIC::make_lagged_data(X_all, lags=1:7)
X_lagged <- MOSAIC::make_lagged_data(X, lags=1:7)

X_all_scaled <- scale(X_all_lagged)
X_scaled <- scale(X_lagged)
Expand Down Expand Up @@ -146,33 +149,123 @@ est_suitability <- function(PATHS, include_lagged_covariates=FALSE) {
cat('Number of samples in test set:', dim(X_test)[1], '\n')
cat('Number of positive samples in test set:', sum(y_test == 1), '\n')



message("Compiling LSTM model...")
# Define an exponential decay schedule for learning rate
lr_schedule <- tf$keras$optimizers$schedules$ExponentialDecay(
initial_learning_rate = 0.001,
decay_steps = 10000,
decay_rate = 0.9
)

# Step 6: Build and Compile the LSTM model
model <- keras_model_sequential() %>%
layer_lstm(units = 200, input_shape = c(timesteps, n_features), return_sequences = FALSE,
layer_lstm(units = 500, input_shape = c(timesteps, n_features), return_sequences = TRUE,
kernel_regularizer = regularizer_l2(0.001)) %>%
layer_dropout(rate = 0.2) %>%
layer_dropout(rate = 0.5) %>%
layer_lstm(units = 250, return_sequences = TRUE, kernel_regularizer = regularizer_l2(0.001)) %>%
layer_dropout(rate = 0.5) %>%
layer_lstm(units = 200, return_sequences = FALSE, kernel_regularizer = regularizer_l2(0.001)) %>%
layer_dropout(rate = 0.5) %>%
layer_dense(units = 1, activation = 'sigmoid')

# Compile the model with the learning rate schedule
model %>% compile(
optimizer = optimizer_adam(learning_rate = 0.001),
optimizer = optimizer_adam(learning_rate = lr_schedule), # Using the exponential decay schedule
loss = 'binary_crossentropy',
metrics = 'accuracy'
)

print(model)

message("Training LSTM model...")
# Step 7: Train the model
history <- model %>% fit(
X_train,
y_train_array,
epochs = 100,
batch_size = 512,
validation_split = 0.2
epochs = 200,
batch_size = 1024,
validation_split = 0.2,
callbacks = list(callback_early_stopping(patience = 10)) # Early stopping for stability
)

# Step 8: Evaluate the model
message("Calculating overall model fit...")
score <- model %>% evaluate(X_test, y_test_array)
cat('Test loss:', score$loss, '\n')
cat('Test accuracy:', score$acc, '\n')




df <- data.frame(
epoch = 1:length(history$metrics$loss),
loss = history$metrics$loss,
val_loss = history$metrics$val_loss,
accuracy = history$metrics$accuracy,
val_accuracy = history$metrics$val_accuracy
)

# Step 8: Evaluate the model and get final test loss and accuracy
message("Calculating overall model fit...")
score <- model %>% evaluate(X_test, y_test_array)
final_loss <- score$loss
final_accuracy <- score$acc

cat('Test loss:', final_loss, '\n')
cat('Test accuracy:', final_accuracy, '\n')



df <- data.frame(
epoch = 1:length(history$metrics$loss),
loss = history$metrics$loss,
val_loss = history$metrics$val_loss,
accuracy = history$metrics$accuracy,
val_accuracy = history$metrics$val_accuracy
)

loss_plot <-
ggplot(df, aes(x = epoch)) +
geom_line(aes(y = loss, color = "Training Loss")) +
geom_line(aes(y = val_loss, color = "Validation Loss")) +
labs(x = "Epoch", y = "Loss", title = "Training and Validation Loss") +
scale_color_manual(name = "",
values = c("Training Loss" = "blue3",
"Validation Loss" = "red3")) +
theme_bw() + # White background
theme(legend.position = "bottom") +
annotate("text", x = max(df$epoch), y = max(df$loss),
label = paste("Final Test Loss:", round(final_loss, 2)),
vjust=5, hjust=1, color = "blue3")

accuracy_plot <-
ggplot(df, aes(x = epoch)) +
geom_line(aes(y = accuracy, color = "Training Accuracy")) +
geom_line(aes(y = val_accuracy, color = "Validation Accuracy")) +
labs(x = "Epoch", y = "Accuracy", title = "Training and Validation Accuracy") +
scale_color_manual(name = "",
values = c("Training Accuracy" = "green4",
"Validation Accuracy" = "darkorange")) +
theme_bw() + # White background
theme(legend.position = "bottom") +
annotate("text", x = max(df$epoch), y = min(df$accuracy),
label = paste("Final Test Accuracy:", round(final_accuracy, 2)),
vjust=-5, hjust=1, color = "green4")

combined_plot <- plot_grid(accuracy_plot, loss_plot, labels = "AUTO", ncol = 2, align = "v")

print(combined_plot)

plot_file <- file.path(PATHS$DOCS_FIGURES, "suitability_LSTM_fit.png")
ggplot2::ggsave(filename = plot_file, plot = combined_plot, width = 8, height = 4, units = "in", dpi = 300)
message(glue::glue("Model fit plot saved to: {plot_file}"))




message("Predicting response values...")
# Step 9: Make predictions on all data
d$pred <- model %>% predict(X_reshaped)
d_all$pred <- model %>% predict(X_all_reshaped)
Expand All @@ -196,17 +289,13 @@ est_suitability <- function(PATHS, include_lagged_covariates=FALSE) {
# Remove rows where pred is NA before applying LOESS to avoid errors
filter(!is.na(pred)) %>%
mutate(
# Use LOESS smoothing after ensuring that all rows are present
pred_smooth = stats::predict(loess(pred ~ as.numeric(date_start), span = 0.01))
pred_smooth = inv_logit(stats::predict(loess(logit(pred) ~ as.numeric(date_start), span = 0.01)))
) %>%
ungroup()


if (53 %in% d_all$week) stop("week index is out of bounds")

#tmp <- d_all[d_all$iso_code == "MOZ",]
#plot(tmp$date_start, tmp$pred, type='l')
#lines(tmp$date_start, tmp$pred_smooth, col='red')

# Save predictions to CSV
path <- file.path(PATHS$MODEL_INPUT, "pred_psi_suitability.csv")
Expand Down
152 changes: 98 additions & 54 deletions R/est_symptomatic_prop.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,76 +27,120 @@ est_symptomatic_prop <- function(PATHS) {
# Read the symptomatic proportion data
df <- utils::read.csv(file.path(PATHS$DATA_SYMPTOMATIC, "summary_symptomatic_cases.csv"))

# Simulate sigma (proportion of infections that are symptomatic) using a Beta distribution

# Make a Beta distribution to simulate sigma (proportion of infections that are symptomatic)

quantiles <- c(0.0001, 0.0275, 0.25, 0.5, 0.75, 0.975, 0.9999)
probs <- c(min(df$ci_lo, na.rm = TRUE),
quantile(df$ci_lo, probs = 0.0275, na.rm = TRUE),
quantile(df$ci_lo, probs = 0.25, na.rm = TRUE),
mean(df$mean, na.rm = TRUE),
quantile(df$ci_hi, probs = 0.75, na.rm = TRUE),
quantile(df$ci_hi, probs = 0.975, na.rm = TRUE),
max(df$ci_hi, na.rm = TRUE))

prm <- propvacc::get_beta_params(quantiles = quantiles, probs = probs)
samps <- stats::rbeta(1000, shape1 = prm$shape1, shape2 = prm$shape2)
ci <- stats::quantile(samps, probs = c(0.025, 0.5, 0.975))

param_df <- MOSAIC::make_param_df(
variable_name = 'sigma',
variable_description = 'proportion symptomatic',
parameter_distribution = 'beta',
parameter_name = names(prm),
parameter_value = unlist(prm)
)
probs <- c(min(df$ci_lo, na.rm=T),
quantile(df$ci_lo, probs=0.0275, na.rm=T),
quantile(df$ci_lo, probs=0.25, na.rm=T),
mean(df$mean, na.rm=T),
quantile(df$ci_hi, probs=0.75, na.rm=T),
quantile(df$ci_hi, probs=0.975, na.rm=T),
max(df$ci_hi, na.rm=T))

path <- file.path(PATHS$MODEL_INPUT, "param_sigma_prop_symptomatic.csv")
write.csv(param_df, path, row.names = FALSE)
message(paste("Parameter data frame for symptomatic proportion (theta) saved to:", path))
prm <- get_beta_params(quantiles=quantiles, probs=probs)

# Create Beta distribution plot
samps <- rbeta(1000, shape1 = prm$shape1, shape2 = prm$shape2)
ci <- quantile(samps, probs = c(0.025, 0.5, 0.975))
df_samples <- data.frame(x = samps)

x_vals <- seq(0, 1, length.out = 1000)
y_vals <- stats::dbeta(x_vals, prm$shape1, prm$shape2)
y_vals <- dbeta(x_vals, prm$shape1, prm$shape2)
df_beta <- data.frame(x = x_vals, y = y_vals)

p1 <- ggplot2::ggplot(df_samples, ggplot2::aes(x = x)) +
ggplot2::geom_histogram(ggplot2::aes(y = ..density..), bins = 35, fill = "#1B4F72", color = 'white', alpha = 0.5) +
ggplot2::geom_line(data = df_beta, ggplot2::aes(x = x, y = y), color = "black", size = 1) +
ggplot2::geom_vline(xintercept = ci[c(1, 3)], linetype = "dashed", color = "grey20", size = 0.25) +
ggplot2::geom_vline(xintercept = ci[2], linetype = "dashed", color = "grey20", size = 0.25) +
ggplot2::labs(title = "A", x = "", y = "") +
ggplot2::scale_x_continuous(limits = c(-0.02, 1.25), breaks = seq(0, 1, 0.25), expand = c(0, 0)) +
ggplot2::scale_y_continuous(expand = c(0.005, 0.005)) +
ggplot2::theme_minimal(base_size = 14)

# Comparison with previous studies
p1 <-
ggplot(df_samples, aes(x = x)) +
geom_histogram(aes(y = ..density..), bins = 35, fill = "#1B4F72", color='white', alpha = 0.5) +
#stat_function(fun = dbeta, args = list(shape1 = prm$shape1, shape2 = prm$shape2), color = "black", size = 1) +
geom_line(data = df_beta, aes(x = x, y = y), color = "black", size = 1) + # Plot Beta distribution
geom_vline(xintercept = ci[c(1,3)], linetype = "dashed", color = "grey20", size = 0.25) +
geom_vline(xintercept = ci[2], linetype = "dashed", color = "grey20", size = 0.25) +
labs(title = "A", x = "", y = "") +
scale_x_continuous(limits = c(-0.02, 1.25), breaks=seq(0, 1, 0.25), expand=c(0,0)) +
scale_y_continuous(expand=c(0.005, 0.005)) +
theme_minimal(base_size = 14) +
theme(
legend.position = "none",
panel.grid.major.y = element_blank(),
panel.grid.major.x = element_line(color='grey80', size=0.25),
panel.grid.minor = element_blank(),
axis.title.x = element_text(margin = margin(t = 30), hjust=0.3),
plot.margin = unit(c(0.25, 0.25, 0, 0), "inches")
)



# How does this compare with previous studies?



#pal <- c("#1B4F72", "#239B56", "#884EA0", "#D35400", "#7D3C98", "#566573", "#CD6155", "#5D6D7E", "#AF601A")
pal <- c("#274001", "#828a00", "#D35400", "#7D3C98", "#1B4F72", "#a62f03", "#400d01", "#4d8584")

p2 <- ggplot2::ggplot(df, ggplot2::aes(x = source, y = mean, color = source)) +
ggplot2::geom_hline(yintercept = ci[c(1, 3)], linetype = "dashed", color = "grey20", size = 0.25) +
ggplot2::geom_hline(yintercept = ci[2], linetype = "dashed", color = "grey20", size = 0.25) +
ggplot2::geom_rect(ggplot2::aes(xmin = as.numeric(factor(source)) - 0.2,
xmax = as.numeric(factor(source)) + 0.2,
ymin = ci_lo, ymax = ci_hi, fill = source), alpha = 0.3, color = NA) +
ggplot2::geom_rect(ggplot2::aes(xmin = as.numeric(factor(source)) - 0.2,
xmax = as.numeric(factor(source)) + 0.2,
ymin = mean - 0.0025, ymax = mean + 0.0025, fill = source)) +
ggplot2::labs(title = "B", x = "", y = "Proportion of infections that are symptomatic") +
ggplot2::scale_fill_manual(values = pal) +
ggplot2::scale_color_manual(values = pal) +
ggplot2::theme_minimal(base_size = 14) +
ggplot2::coord_flip()

# Combine and save plot
combo <- cowplot::plot_grid(p1, p2, ncol = 1, rel_heights = c(1, 1.5), align = 'vh')
p2 <-
ggplot(df, aes(x = source, y = mean, color = source)) +
geom_hline(yintercept = ci[c(1,3)], linetype = "dashed", color = "grey20", size = 0.25) +
geom_hline(yintercept = ci[2], linetype = "dashed", color = "grey20", size = 0.25) +
geom_rect(aes(xmin = as.numeric(factor(source)) - 0.2,
xmax = as.numeric(factor(source)) + 0.2,
ymin = ci_lo, ymax = ci_hi,
fill = source),
alpha = 0.3, color=NA) +
geom_rect(aes(xmin = as.numeric(factor(source)) - 0.2,
xmax = as.numeric(factor(source)) + 0.2,
ymin = mean-0.0025, ymax = mean+0.0025, fill = source)) +
geom_text(aes(label = note2), vjust = -1, hjust = 1.25, size=3) +
annotate("text", x = 1, y = 1.02, hjust = 0, vjust = 0.5, label = "Pakistan", color = pal[1], alpha=0.7, size = 3.5) +
annotate("text", x = 2, y = 1.02, hjust = 0, vjust = 0.5, label = "Haiti", color = pal[2], alpha=0.7, size = 3.5) +
annotate("text", x = 3, y = 1.02, hjust = 0, vjust = 0.5, label = "Bangladesh", color = pal[3], alpha=0.7, size = 3.5) +
annotate("text", x = 4, y = 1.02, hjust = 0, vjust = 0.5, label = "Endemic regions", color = pal[4], alpha=0.7, size = 3.5) +
annotate("text", x = 5, y = 1.02, hjust = 0, vjust = 0.5, label = "Bangladesh", color = pal[5], alpha=0.7, size = 3.5) +
annotate("text", x = 6, y = 1.02, hjust = 0, vjust = 0.5, label = "Haiti", color = pal[6], alpha=0.7, size = 3.5) +
labs(title = "B", x = "", y = "Proportion of infections that are symptomatic") +
scale_fill_manual(values = pal) +
scale_color_manual(values = pal) +
scale_y_continuous(limits = c(-0.02, 1.25), breaks=seq(0, 1, 0.25), expand=c(0,0)) +
theme_minimal(base_size = 14) +
theme(
legend.position = "none",
panel.grid.major.y = element_blank(),
panel.grid.major.x = element_line(color='grey80', size=0.25),
panel.grid.minor = element_blank(),
axis.title.x = element_text(margin = margin(t = 30), hjust=0.3),
axis.title.y = element_text(margin = margin(r = 15)),
plot.margin = unit(c(0, 0.25, 0.25, 0), "inches")
) +
coord_flip()


combo <- plot_grid(p1, p2, ncol = 1, rel_heights = c(1, 1.5), align='vh')
print(combo)

figure_path <- file.path(PATHS$DOCS_FIGURES, "proportion_symptomatic.png")

figure_path <- file.path(PATHS$DOCS_FIGURES, "proportion_symptomatic.png")
grDevices::png(filename = figure_path, width = 2400, height = 2400, units = "px", res = 300)
print(combo)
grDevices::dev.off()

message(paste("Symptomatic proportion plot saved to:", figure_path))



# Save parameters

param_df <- MOSAIC::make_param_df(
variable_name = 'sigma',
variable_description = 'proportion symptomatic',
parameter_distribution = 'beta',
parameter_name = names(prm),
parameter_value = unlist(prm)
)

path <- file.path(PATHS$MODEL_INPUT, "param_sigma_prop_symptomatic.csv")
write.csv(param_df, path, row.names = FALSE)
message(paste("Parameter data frame for symptomatic proportion (theta) saved to:", path))



}
2 changes: 2 additions & 0 deletions R/get_ENSO_historical.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ get_ENSO_historical <- function() {

# Helper function to download and process each dataset
process_enso_historical <- function(url, variable_name) {

# Read the raw data, skipping the first line (header) and treating the missing-value code -99.99 as NA
raw_data <- utils::read.table(url, skip = 1, fill = TRUE, na.strings = "-99.99", header = FALSE)

Expand Down Expand Up @@ -60,6 +61,7 @@ get_ENSO_historical <- function() {
enso_long <- dplyr::arrange(enso_long, year, month)

return(enso_long)

}

# Process and combine all datasets (ENSO3, ENSO34, ENSO4)
Expand Down
Loading

0 comments on commit 8e561fe

Please sign in to comment.