Skip to content

Commit

Permalink
Fix variable skeleton with containers
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Aug 23, 2023
1 parent f028607 commit 6cae03a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
6 changes: 5 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,11 @@ create_skeleton <- function(param_metadata, model_variables,
names(model_variables$generated_quantities))
}
lapply(param_metadata[target_params], function(par_dims) {
array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims))
if ((length(par_dims) == 0)) {
array(0, dim = 1)
} else {
array(0, dim = par_dims)
}
})
}

Expand Down
33 changes: 33 additions & 0 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,36 @@ test_that("Model methods can be initialised for models with no data", {
expect_no_error(fit <- mod$sample())
expect_equal(fit$log_prob(5), -12.5)
})

test_that("Variable skeleton returns correct dimensions for matrices", {
skip_if(os_is_wsl())

stan_file <- write_stan_file("
data {
int N;
int K;
}
parameters {
real x_real;
matrix[N,K] x_mat;
vector[K] x_vec;
row_vector[K] x_rowvec;
}
model {
x_real ~ std_normal();
}")
mod <- cmdstan_model(stan_file, compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 4, K = 3), chains = 1,
iter_warmup = 1, iter_sampling = 1)

target_skeleton <- list(
x_real = array(0, dim = 1),
x_mat = array(0, dim = c(4, 3)),
x_vec = array(0, dim = c(3)),
x_rowvec = array(0, dim = c(3))
)

expect_equal(fit$variable_skeleton(),
target_skeleton)
})

0 comments on commit 6cae03a

Please sign in to comment.