Skip to content

Commit

Permalink
Merge pull request #8 from nwfsc-cb/randomeffects
Browse files Browse the repository at this point in the history
naming
  • Loading branch information
ericward-noaa authored Nov 18, 2023
2 parents ca0e3da + 99841a4 commit a0be03b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
8 changes: 6 additions & 2 deletions R/fitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ fit_zoid <- function(formula = NULL,
tot_re = 1,
n_groups = 1)
est_re <- FALSE
re_group_names <- NA
if (!is.null(formula)) {
model_frame <- model.frame(formula, design_matrix)
model_matrix <- model.matrix(formula, model_frame)
Expand All @@ -77,6 +78,7 @@ fit_zoid <- function(formula = NULL,
parsed_res <- res # only update if REs are in formula
est_re <- TRUE
model_matrix <- res$fixed_design_matrix
re_group_names <- res$random_effect_group_names
}
} else {
model_matrix <- matrix(1, nrow = nrow(data_matrix))
Expand Down Expand Up @@ -139,7 +141,8 @@ fit_zoid <- function(formula = NULL,
overdispersion = overdispersion,
overdispersion_prior = prior,
posterior_predict = posterior_predict,
stan_data = stan_data
stan_data = stan_data,
re_group_names = re_group_names
))
}

Expand Down Expand Up @@ -195,5 +198,6 @@ parse_re_formula <- function(formula, data) {
if(length(var_indx) > 0) n_groups <- max(var_indx)
return(list(design_matrix = design_matrix, var_indx = var_indx, n_re_by_group = n_re,
tot_re = sum(n_re), n_groups = n_groups,
fixed_design_matrix = fixed_design_matrix))
fixed_design_matrix = fixed_design_matrix,
random_effect_group_names = random_effect_group_names))
}
44 changes: 29 additions & 15 deletions R/get_pars.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#' @return A list containing the posterior summaries of estimated parameters. At minimum,
#' this will include `p` (the estimated proportions) and `betas` (the predicted values in
#' transformed space). For models with overdispersion, an extra
#' element `phi` will also be returned, summarizing overdispersion. For predictions
#' element `phi` will also be returned, summarizing overdispersion. For models with random
#' intercepts, estimates of the group level effects will also be returned as `zetas` (again,
#' in transformed space). For predictions
#' in normal space, see [get_fitted()]
#' @importFrom rstan extract
#' @importFrom stats median quantile
Expand All @@ -36,7 +38,7 @@ get_pars <- function(fitted_model, conf_int = 0.05) {
n_group <- dim(pars$beta)[2]
n_cov <- dim(pars$beta)[3]
betas <- expand.grid(
"group" = seq(1, n_group),
"m" = seq(1, n_group),
"cov" = seq(1, n_cov),
"par" = NA,
"mean" = NA,
Expand All @@ -46,10 +48,10 @@ get_pars <- function(fitted_model, conf_int = 0.05) {
)

for (i in 1:nrow(betas)) {
betas$mean[i] <- mean(pars$beta[, betas$group[i], betas$cov[i]])
betas$median[i] <- median(pars$beta[, betas$group[i], betas$cov[i]])
betas$lo[i] <- quantile(pars$beta[, betas$group[i], betas$cov[i]], conf_int / 2.0)
betas$hi[i] <- quantile(pars$beta[, betas$group[i], betas$cov[i]], 1 - conf_int / 2.0)
betas$mean[i] <- mean(pars$beta[, betas$m[i], betas$cov[i]])
betas$median[i] <- median(pars$beta[, betas$m[i], betas$cov[i]])
betas$lo[i] <- quantile(pars$beta[, betas$m[i], betas$cov[i]], conf_int / 2.0)
betas$hi[i] <- quantile(pars$beta[, betas$m[i], betas$cov[i]], 1 - conf_int / 2.0)
betas$par[i] <- fitted_model$par_names[betas$cov[i]]
}

Expand All @@ -66,24 +68,36 @@ get_pars <- function(fitted_model, conf_int = 0.05) {

# include zetas (random group intercepts)
if (fitted_model$stan_data$est_re == 1) {
n_group <- dim(pars$zeta)[2]
n_cov <- dim(pars$zeta)[3]
m <- dim(pars$zeta)[2]
group <- dim(pars$zeta)[3]
zetas <- expand.grid(
"group" = seq(1, n_group),
"cov" = seq(1, n_cov),
"m" = seq(1, m),
"group" = seq(1, group),
"par" = NA,
"mean" = NA,
"median" = NA,
"lo" = NA,
"hi" = NA
)
for (i in 1:nrow(zetas)) {
zetas$mean[i] <- mean(pars$zeta[, zetas$group[i], zetas$cov[i]])
zetas$median[i] <- median(pars$zeta[, zetas$group[i], zetas$cov[i]])
zetas$lo[i] <- quantile(pars$zeta[, zetas$group[i], zetas$cov[i]], conf_int / 2.0)
zetas$hi[i] <- quantile(pars$zeta[, zetas$group[i], zetas$cov[i]], 1 - conf_int / 2.0)
zetas$par[i] <- fitted_model$par_names[zetas$cov[i]]
zetas$mean[i] <- mean(pars$zeta[, zetas$m[i], zetas$group[i]])
zetas$median[i] <- median(pars$zeta[, zetas$m[i], zetas$group[i]])
zetas$lo[i] <- quantile(pars$zeta[, zetas$m[i], zetas$group[i]], conf_int / 2.0)
zetas$hi[i] <- quantile(pars$zeta[, zetas$m[i], zetas$group[i]], 1 - conf_int / 2.0)
zetas$par[i] <- fitted_model$par_names[zetas$group[i]]
}
# add group names
for(i in 1:fit$stan_data$n_groups) {
if(i==1) {
ids <- rep(i,fit$stan_data$n_re_by_group[i])
} else {
ids <- c(ids, rep(i,fit$stan_data$n_re_by_group[i]))
}
}
df <- data.frame("group" = 1:max(zetas$group), "group_name" = fit$re_group_names[ids])
zetas$group_name <- ""
for(i in 1:nrow(zetas)) zetas$group_name[i] <- df$group_name[which(df$group == zetas$group[i])]
zetas <- zetas[,c("m","group","group_name","par","mean","median","lo","hi")]
par_list$zetas <- zetas
}

Expand Down
Binary file modified src/zoid.so
Binary file not shown.

0 comments on commit a0be03b

Please sign in to comment.