Skip to content

Commit

Permalink
fix predict mixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jsocolar committed Dec 10, 2023
1 parent f408146 commit 7f818be
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
29 changes: 16 additions & 13 deletions R/predict_flocker.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ predict_flocker <- function(flocker_fit, draw_ids = NULL,
)

assertthat::assert_that(is.null(new_data) | is_flocker_data(new_data))

new_data2 <- new_data
if (is.null(new_data)) {
new_data <- flocker_fit$data
} else {
new_data <- new_data$data
}

total_iter <- brms::ndraws(flocker_fit)

Expand All @@ -70,6 +63,12 @@ predict_flocker <- function(flocker_fit, draw_ids = NULL,

# rename all random effect levels so they show up as new levels
if (mixed) {
if (is.null(new_data)) {
new_data <- flocker_fit$data
} else {
new_data <- new_data$data
}

random_effects <- flocker_fit$ranef$group
if (length(random_effects) > 0) {
potential_conflicts <- vector()
Expand All @@ -88,36 +87,40 @@ predict_flocker <- function(flocker_fit, draw_ids = NULL,
new_data[, random_effects[i]] <- paste0(new_data[, random_effects[i]],
"_resampled")
}
new_data <- list(data = new_data)
class(new_data) <- "flocker_data"
}
sample_new_levels = "gaussian"
message("`sample_new_levels` set to 'gaussian' for mixed predictive checking")
assertthat::assert_that(
sample_new_levels = "gaussian",
msg = "set `sample_new_levels` to 'gaussian' for mixed predictive checking"
)
}


Z_samp <- get_Z(flocker_fit, draw_ids = draw_ids, history_condition = history_condition,
sample = TRUE, new_data = new_data2,
sample = TRUE, new_data = new_data,
allow_new_levels = allow_new_levels, sample_new_levels = sample_new_levels)

lps <- fitted_flocker(
flocker_fit,
components = "det",
draw_ids = draw_ids, new_data = new_data2, allow_new_levels = allow_new_levels,
draw_ids = draw_ids, new_data = new_data, allow_new_levels = allow_new_levels,
sample_new_levels = sample_new_levels, response = FALSE, unit_level = FALSE
)
theta_all <- boot::inv.logit(lps$linpred_det)
ndim <- length(dim(theta_all))

assertthat::assert_that(
ndim > 2,
msg = "this shouldn't happen; please report a bug"
msg = "predict_flocker error 1. This shouldn't happen; please report a bug"
)

Z_samp_array <- abind::abind(rep(list(Z_samp), dim(theta_all)[2]), along = ndim) |>
aperm(perm = c(1, ndim, (2:(ndim - 1))))

assertthat::assert_that(
identical(dim(theta_all), dim(Z_samp_array)),
msg = "this shouldn't happen; please report a bug"
msg = "predict_flocker error 2. This shouldn't happen; please report a bug"
)

predictions <- new_array(theta_all, stats::rbinom(length(theta_all), 1, theta_all * Z_samp_array))
Expand Down
6 changes: 3 additions & 3 deletions man/fitted_flocker.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7f818be

Please sign in to comment.