Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add structure to foi_index across the package #221

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 67 additions & 30 deletions R/build_stan_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,40 +67,63 @@ sf_none <- function() {
#' be estimated when sampling.
ntorresd marked this conversation as resolved.
Show resolved Hide resolved
#' @inheritParams fit_seromodel
#' @param group_size Size of the groups of consecutive ages/years that share
#' the same FOI index
#' @return Integer vector with the indexes numerating each year/age
#' (depending on the model).
#' @param model_type Type of the model. Either "age" or "time"
#' @return Data frame with the indexes numerating each age/year (depending on
ntorresd marked this conversation as resolved.
Show resolved Hide resolved
#' the model). A single FOI value will be estimated for ages/years assigned
#' with the same index
#' @examples
#' data(chagas2012)
#' foi_index <- get_foi_index(chagas2012, group_size = 25)
#' foi_index <- get_foi_index(chagas2012, group_size = 25, model_type = "time")
#' @export
get_foi_index <- function(
  serosurvey,
  group_size,
  model_type
) {
  # Builds the table that maps every age (model_type = "age") or every year
  # before the survey (model_type = "time") to the index of the FOI value it
  # shares. Consecutive ages/years are grouped in chunks of `group_size`.

  # Check model_type corresponds to a valid model. This check runs first so
  # that an unsupported model is reported before any other argument is used.
  stopifnot(
    "model_type must be either 'time' or 'age'" =
      model_type %in% c("time", "age")
  )

  max_age <- max(serosurvey$age_max)

  # Check group_size dimension is in the right range
  checkmate::assert_int(
    group_size,
    lower = 1,
    upper = max_age
  )

  # Number of complete chunks of `group_size` ages/years
  n_groups <- max_age %/% group_size
  foi_indexes <- rep(seq(1, n_groups, 1), each = group_size)

  # When max_age is not a multiple of group_size, the leftover ages/years
  # are assigned to the last chunk instead of getting a chunk of their own
  foi_indexes <- c(
    foi_indexes,
    rep(n_groups, max_age - length(foi_indexes))
  )

  if (model_type == "time") {
    survey_year <- unique(serosurvey$survey_year)
    # The year axis is anchored on the survey year, so it must be unique
    if (length(survey_year) != 1) {
      stop(
        "serosurvey must contain exactly one survey_year to index the ",
        "FOI by year",
        call. = FALSE
      )
    }
    foi_index <- data.frame(
      year = seq(survey_year - max_age, survey_year - 1),
      foi_index = foi_indexes
    )
  } else {
    foi_index <- data.frame(
      age = seq(1, max_age, 1),
      foi_index = foi_indexes
    )
  }

  return(foi_index)
}
Expand Down Expand Up @@ -194,19 +217,33 @@ build_stan_data <- function(
set_stan_data_defaults(
is_log_foi = is_log_foi,
is_seroreversion = is_seroreversion
)
)

if (is.null(foi_index)) {
foi_index_default <- get_foi_index(serosurvey = serosurvey, group_size = 1)
if (model_type == "constant") {
stan_data <- c(
stan_data,
list(foi_index = rep(1, max(serosurvey$age_max)))
)
} else if (is.null(foi_index) && model_type != "constant") {
foi_index_default <- get_foi_index(
serosurvey = serosurvey,
group_size = 1,
model_type = model_type
)
stan_data <- c(
stan_data,
list(foi_index = foi_index_default)
list(foi_index = foi_index_default$foi_index)
)
} else {
# TODO: check that foi_index is the right size
validate_foi_index(
foi_index = foi_index,
serosurvey = serosurvey,
model_type = model_type
)

stan_data <- c(
stan_data,
list(foi_index = foi_index)
list(foi_index = foi_index$foi_index)
)
}
config_file <- system.file("extdata", "config.yml", package = "serofoi")
Expand Down
2 changes: 1 addition & 1 deletion R/fit_seromodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ set_foi_init <- function(
#' seromodel <- fit_seromodel(
#' serosurvey = veev2012,
#' model_type = "time",
#' foi_index = get_foi_index(veev2012, group_size = 30)
#' foi_index = get_foi_index(veev2012, group_size = 30, model_type = "time")
#' )
#' @export
fit_seromodel <- function(
Expand Down
25 changes: 25 additions & 0 deletions R/validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,28 @@ validate_survey_and_foi_consistency_age_time <- function( #nolint
"not exceed max age in survey_features."
)
}

validate_foi_index <- function(
  foi_index,
  serosurvey,
  model_type
) {
  # Validates a user-supplied foi_index table against the serosurvey and the
  # model type. Structural problems that make the table unusable raise an
  # error; suspicious but usable index layouts only raise a warning.

  # Check model_type corresponds to a valid model
  stopifnot(
    "model_type must be either 'time' or 'age'" =
      model_type %in% c("time", "age")
  )

  # A bare vector has no rows, which would let the size check below pass
  # vacuously, so the container type is checked explicitly
  checkmate::assert_data_frame(foi_index)

  # validate that foi_index has the right columns; `foi_index` is the column
  # the caller extracts, so it is required for both model types
  if (model_type == "age") {
    checkmate::assert_names(
      names(foi_index),
      must.include = c("age", "foi_index")
    )
  } else if (model_type == "time") {
    checkmate::assert_names(
      names(foi_index),
      must.include = c("year", "foi_index")
    )
  }

  # validate that foi_index has the right size
  stopifnot(
    "foi_index must be the right size" =
      nrow(foi_index) == max(serosurvey$age_max)
  )

  # Indexes are whole numbers counted from 1, as produced by get_foi_index
  indexes <- foi_index$foi_index
  checkmate::assert_integerish(indexes, lower = 1, any.missing = FALSE)

  if (length(indexes) > 0) {
    # Every integer up to the largest index should be used: 1, 1, 1, 3, 3
    # skips index 2
    skipped <- setdiff(seq_len(max(indexes)), indexes)
    if (length(skipped) > 0) {
      warning(
        "foi_index skips the following indexes: ",
        toString(skipped),
        call. = FALSE
      )
    }

    # Each index should label one consecutive chunk: 1, 1, 2, 2, 1, 1 reuses
    # index 1 after index 2 has started
    if (anyDuplicated(rle(indexes)$values) > 0) {
      warning(
        "foi_index chunks are not consecutive: ",
        "an index is reused after a different index",
        call. = FALSE
      )
    }
  }

  invisible(NULL)
}
2 changes: 1 addition & 1 deletion man/fit_seromodel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions man/get_foi_index.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 86 additions & 0 deletions tests/testthat/test-get_foi_index.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Setup for testing ----
# Shared fixture: five consecutive 5-year age groups (ages 1 to 25), all
# surveyed in 2025
survey_features <- data.frame(
  age_min = seq(1, 21, by = 5),
  age_max = seq(5, 25, by = 5),
  survey_year = 2025
)

# Test get_foi_index ----
test_that("get_foi_index returns correct output for valid model types", {
  max_age <- max(survey_features$age_max)
  survey_year <- unique(survey_features$survey_year)
  group_size <- 5
  n_groups <- max_age / group_size

  # Structural contract of the index column: the largest index equals the
  # number of groups, every index up to it is present, and each index labels
  # a single consecutive chunk
  expect_well_formed_index <- function(index) {
    expect_equal(max(index), n_groups)
    expect_equal(sort(unique(index)), seq_len(n_groups))
    expect_equal(rle(index)$values, seq_len(n_groups))
    expect_equal(rle(index)$lengths, rep(group_size, n_groups))
  }

  # Test for model_type "age"
  result_age <- get_foi_index(
    survey_features,
    group_size = group_size,
    model_type = "age"
  )

  # Check if the data frame has the correct structure for "age"
  expect_true(all(c("age", "foi_index") %in% names(result_age)))
  expect_equal(nrow(result_age), max_age)
  expect_equal(result_age$age, seq_len(max_age))
  expect_well_formed_index(result_age$foi_index)

  # Test for model_type "time"
  result_time <- get_foi_index(
    survey_features,
    group_size = group_size,
    model_type = "time"
  )

  # Check if the data frame has the correct structure for "time"; years run
  # from max_age years before the survey up to the year before the survey
  expect_true(all(c("year", "foi_index") %in% names(result_time)))
  expect_equal(nrow(result_time), max_age)
  expect_equal(
    result_time$year,
    seq(survey_year - max_age, survey_year - 1)
  )
  expect_well_formed_index(result_time$foi_index)
})

test_that("get_foi_index returns an error for invalid model_type", {
  # "constant" is not a supported model_type. The survey fixture and a valid
  # group_size are supplied so that the only possible failure is the
  # model_type check, independently of the order in which arguments are
  # evaluated.
  expect_error(
    get_foi_index(survey_features, group_size = 5, model_type = "constant"),
    regexp = "model_type must be either 'time' or 'age'"
  )
})

test_that("get_foi_index handles different group_size correctly", {
  oldest_age <- max(survey_features$age_max)

  # Edge case: group_size = 1 gives every age its own index
  one_per_age <- get_foi_index(
    survey_features,
    group_size = 1,
    model_type = "age"
  )
  expect_equal(nrow(one_per_age), oldest_age)
  expect_equal(one_per_age$foi_index, seq(1, oldest_age))

  # Edge case: group_size equal to the maximum age gives a single index
  single_group <- get_foi_index(
    survey_features,
    group_size = oldest_age,
    model_type = "time"
  )
  expect_equal(nrow(single_group), oldest_age)
  expect_equal(single_group$foi_index, rep(1, 25))

  # Max age not divisible by group_size: the remaining times are indexed in
  # the last chunk
  uneven_groups <- get_foi_index(
    survey_features,
    group_size = 7,
    model_type = "time"
  )
  expect_equal(nrow(uneven_groups), oldest_age)
  expect_equal(
    uneven_groups$foi_index,
    rep(c(1, 2, 3), times = c(7, 7, 7 + 4))
  )
})

test_that("get_foi_index throws an error for invalid group_size", {
  group_size_failure <- "Assertion on 'group_size' failed"

  # group_size larger than the maximum age_max in the survey
  expect_error(
    get_foi_index(survey_features, group_size = 30, model_type = "age"),
    regexp = group_size_failure
  )

  # group_size smaller than 1
  expect_error(
    get_foi_index(survey_features, group_size = 0, model_type = "age"),
    regexp = group_size_failure
  )
})
19 changes: 15 additions & 4 deletions vignettes/articles/foi_models.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,14 @@ mutate(survey_year = 2050)
The simulated dataset `foi_sim_sw_dec` contains information about 250 samples of individuals between 1 and 50 years old (5 samples per age), grouped into 5-year age groups. The following code shows how to fit the slow time-varying normal model to this simulated serosurvey:

```{r tv_normal model, include = TRUE, echo = TRUE, results="hide", errors = FALSE, warning = FALSE, message = FALSE, fig.width=4, fig.asp=1.5, fig.align="center", out.width ="50%", fig.keep="all"}

foi_index <- data.frame(
year = seq(2000, 2049),
foi_index = rep(c(1, 2, 3), c(25, 10, 15))
)
seromodel_time_normal <- fit_seromodel(
serosurvey = serosurvey_sw_dec,
model_type = "time",
foi_index = rep(c(1, 2, 3), c(25, 10, 15)),
foi_index = foi_index,
iter = 1500
)
plot_seromodel(
Expand Down Expand Up @@ -257,11 +260,15 @@ mutate(survey_year = 2050)
The simulated serosurvey tests 250 individuals between 1 and 50 years old in the year 2050. The fast epidemic model can be fitted by running the following lines of code:

```{r tv_normal_log model, include = TRUE, echo = TRUE, results="hide", errors = FALSE, warning = FALSE, message = FALSE, fig.width=4, fig.asp=1.5, fig.align="center", out.width ="50%", fig.keep="all"}
foi_index <- data.frame(
year = seq(2000, 2049),
foi_index = rep(c(1, 2, 3), c(30, 5, 15))
)
seromodel_log_time_normal <- fit_seromodel(
serosurvey = serosurvey_large_epi,
model_type = "time",
is_log_foi = TRUE,
foi_index = rep(c(1, 2, 3), c(30, 5, 15)),
foi_index = foi_index,
iter = 2000
)

Expand Down Expand Up @@ -296,10 +303,14 @@ plot_constant <- plot_seromodel(
size_text = 6
)

foi_index <- data.frame(
year = seq(2000, 2049),
foi_index = rep(c(1, 2, 3), c(30, 5, 15))
)
seromodel_time_normal <- fit_seromodel(
serosurvey = serosurvey_large_epi,
model_type = "time",
foi_index = rep(c(1, 2, 3), c(30, 5, 15)),
foi_index = foi_index,
iter = 2000
)
plot_time_normal <- plot_seromodel(
Expand Down
Loading
Loading