From c190c5d473c976f0d4726ca5f42a262903f42f52 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 25 Sep 2023 12:01:52 +0300 Subject: [PATCH 1/2] Fix handling of single-length inits for containers --- R/args.R | 7 +++++++ tests/testthat/test-model-init.R | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/R/args.R b/R/args.R index 7794510e4..56d08611b 100644 --- a/R/args.R +++ b/R/args.R @@ -826,6 +826,13 @@ process_init_list <- function(init, num_procs, model_variables = NULL) { if (!all(is_parameter_value_supplied)) { missing_parameter_values[[i]] <- parameter_names[!is_parameter_value_supplied] } + for (par_name in parameter_names[is_parameter_value_supplied]) { + # Make sure that initial values for single-element containers don't get + # unboxed when writing to JSON + if (model_variables$parameters[[par_name]]$dimensions == 1 && is.null(attr(init[[i]][[par_name]], "dim"))) { + init[[i]][[par_name]] <- array(init[[i]][[par_name]], dim = 1) + } + } } if (length(missing_parameter_values) > 0) { warning_message <- c( diff --git a/tests/testthat/test-model-init.R b/tests/testthat/test-model-init.R index 221c8dfb3..cbf66c264 100644 --- a/tests/testthat/test-model-init.R +++ b/tests/testthat/test-model-init.R @@ -262,3 +262,25 @@ test_that("print message if not all parameters are initialized", { fixed = TRUE ) }) + +test_that("Initial values for single-element containers treated correctly", { + modcode <- " + data { + real y_mean; + } + parameters { + vector[1] y; + } + model { + y_mean ~ normal(y[1], 1); + } + " + mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE) + expect_no_error( + fit <- mod$sample( + data = list(y_mean = 0), + init = list(list(y = c(0))), + chains = 1 + ) + ) +}) From 50616d0d070913c2112ac12b5b3f2a312f41107e Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 25 Sep 2023 18:37:54 +0300 Subject: [PATCH 2/2] Fix dim check --- R/args.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/args.R b/R/args.R index 56d08611b..a2a31af08 100644 --- a/R/args.R +++ b/R/args.R @@ -829,7 +829,7 @@ process_init_list <- function(init, num_procs, model_variables = NULL) { for (par_name in parameter_names[is_parameter_value_supplied]) { # Make sure that initial values for single-element containers don't get # unboxed when writing to JSON - if (model_variables$parameters[[par_name]]$dimensions == 1 && is.null(attr(init[[i]][[par_name]], "dim"))) { + if (model_variables$parameters[[par_name]]$dimensions == 1 && length(init[[i]][[par_name]]) == 1) { init[[i]][[par_name]] <- array(init[[i]][[par_name]], dim = 1) } }