diff --git a/R/utils.R b/R/utils.R index bf65b513d..c2209047c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -780,7 +780,11 @@ create_skeleton <- function(param_metadata, model_variables, names(model_variables$generated_quantities)) } lapply(param_metadata[target_params], function(par_dims) { - array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims)) + if ((length(par_dims) == 0)) { + array(0, dim = 1) + } else { + array(0, dim = par_dims) + } }) } diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index 8352a04f3..a92b181ad 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -277,3 +277,36 @@ test_that("Model methods can be initialised for models with no data", { expect_no_error(fit <- mod$sample()) expect_equal(fit$log_prob(5), -12.5) }) + +test_that("Variable skeleton returns correct dimensions for matrices", { + skip_if(os_is_wsl()) + + stan_file <- write_stan_file(" + data { + int N; + int K; + } + parameters { + real x_real; + matrix[N,K] x_mat; + vector[K] x_vec; + row_vector[K] x_rowvec; + } + model { + x_real ~ std_normal(); + }") + mod <- cmdstan_model(stan_file, compile_model_methods = TRUE, + force_recompile = TRUE) + fit <- mod$sample(data = list(N = 4, K = 3), chains = 1, + iter_warmup = 1, iter_sampling = 1) + + target_skeleton <- list( + x_real = array(0, dim = 1), + x_mat = array(0, dim = c(4, 3)), + x_vec = array(0, dim = c(3)), + x_rowvec = array(0, dim = c(3)) + ) + + expect_equal(fit$variable_skeleton(), + target_skeleton) +})