diff --git a/R/args.R b/R/args.R index c885bb2e..7b4a8686 100644 --- a/R/args.R +++ b/R/args.R @@ -254,6 +254,9 @@ SampleArgs <- R6::R6Class( fileext = ".json" ) for (i in seq_along(inv_metric_paths)) { + if (length(inv_metric[[i]]) == 1 && metric == "diag_e") { + inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1)) + } write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i]) } diff --git a/R/fit.R b/R/fit.R index cc1e2e53..99feca4c 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1742,7 +1742,10 @@ inv_metric <- function(matrix = TRUE) { out <- private$inv_metric_ if (matrix && !is.matrix(out[[1]])) { # convert each vector to a diagonal matrix - out <- lapply(out, diag) + out <- lapply(out, function(x) diag(x, nrow = length(x))) + } else if (length(out[[1]]) == 1) { + # convert each scalar to an array with dimension 1 + out <- lapply(out, array, dim = c(1)) } out } diff --git a/tests/testthat/test-model-sample-metric.R b/tests/testthat/test-model-sample-metric.R index 38159d40..422442fa 100644 --- a/tests/testthat/test-model-sample-metric.R +++ b/tests/testthat/test-model-sample-metric.R @@ -4,6 +4,9 @@ set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") +mod2 <- testing_model("logistic") +data_list2 <- testing_data("logistic") + test_that("sample() method works with provided inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) @@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", { }) +test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", { + expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(dim(inv_metric_vector[[1]]), 1) + expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1)) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + +test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", { + expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(length(inv_metric_vector[[1]]), 4) + expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4)) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + + test_that("sample() method works with lists of inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")