From 6cae03aa1a9344a9a2a8c35acdc5bfbe0f7b9056 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 23 Aug 2023 11:06:42 +0300 Subject: [PATCH 1/2] Fix variable skeleton with containers --- R/utils.R | 6 +++++- tests/testthat/test-model-methods.R | 33 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index bf65b513d..c2209047c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -780,7 +780,11 @@ create_skeleton <- function(param_metadata, model_variables, names(model_variables$generated_quantities)) } lapply(param_metadata[target_params], function(par_dims) { - array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims)) + if ((length(par_dims) == 0)) { + array(0, dim = 1) + } else { + array(0, dim = par_dims) + } }) } diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index 8352a04f3..a92b181ad 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -277,3 +277,36 @@ test_that("Model methods can be initialised for models with no data", { expect_no_error(fit <- mod$sample()) expect_equal(fit$log_prob(5), -12.5) }) + +test_that("Variable skeleton returns correct dimensions for matrices", { + skip_if(os_is_wsl()) + + stan_file <- write_stan_file(" + data { + int N; + int K; + } + parameters { + real x_real; + matrix[N,K] x_mat; + vector[K] x_vec; + row_vector[K] x_rowvec; + } + model { + x_real ~ std_normal(); + }") + mod <- cmdstan_model(stan_file, compile_model_methods = TRUE, + force_recompile = TRUE) + fit <- mod$sample(data = list(N = 4, K = 3), chains = 1, + iter_warmup = 1, iter_sampling = 1) + + target_skeleton <- list( + x_real = array(0, dim = 1), + x_mat = array(0, dim = c(4, 3)), + x_vec = array(0, dim = c(3)), + x_rowvec = array(0, dim = c(3)) + ) + + expect_equal(fit$variable_skeleton(), + target_skeleton) +}) From 97d11427a4603a87bf5be90d1b62603107f17cac Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 23 Aug 2023 11:08:22 +0300 Subject: [PATCH 2/2] Update test --- tests/testthat/test-model-methods.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index a92b181ad..1289cb063 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -297,14 +297,16 @@ test_that("Variable skeleton returns correct dimensions for matrices", { }") mod <- cmdstan_model(stan_file, compile_model_methods = TRUE, force_recompile = TRUE) - fit <- mod$sample(data = list(N = 4, K = 3), chains = 1, + N <- 4 + K <- 3 + fit <- mod$sample(data = list(N = N, K = K), chains = 1, iter_warmup = 1, iter_sampling = 1) target_skeleton <- list( x_real = array(0, dim = 1), - x_mat = array(0, dim = c(4, 3)), - x_vec = array(0, dim = c(3)), - x_rowvec = array(0, dim = c(3)) + x_mat = array(0, dim = c(N, K)), + x_vec = array(0, dim = K), + x_rowvec = array(0, dim = K) ) expect_equal(fit$variable_skeleton(),