From ce557ad8ac68ad57f040040964248de83afcb6c0 Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Tue, 19 Mar 2024 15:26:55 +0100 Subject: [PATCH 1/3] add tests showing failure of inv_metric for 1 parameter --- tests/testthat/test-model-sample-metric.R | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/testthat/test-model-sample-metric.R b/tests/testthat/test-model-sample-metric.R index 38159d40a..7163b4c85 100644 --- a/tests/testthat/test-model-sample-metric.R +++ b/tests/testthat/test-model-sample-metric.R @@ -54,6 +54,42 @@ test_that("sample() method works with provided inv_metrics", { }) +test_that("sample() method works with inv_metrics extracted from previous fit", { + expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(dim(inv_metric_vector[[1]]), 1) + expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1)) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + + test_that("sample() method works with lists of inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json") From 69135f4d0f3e70484946794e57456727f71f2106 Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Tue, 19 Mar 2024 15:48:35 +0100 Subject: [PATCH 2/3] fix case with matrix=FALSE and 1 parameter --- R/args.R | 3 +++ R/fit.R | 3 +++ 2 files changed, 6 insertions(+) diff --git a/R/args.R b/R/args.R index c885bb2e5..6dcf8095c 100644 --- a/R/args.R +++ b/R/args.R @@ -254,6 +254,9 @@ SampleArgs <- R6::R6Class( fileext = ".json" ) for (i in seq_along(inv_metric_paths)) { + if (length(inv_metric[[i]] == 1) && metric == "diag_e") { + inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1)) + } write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i]) } diff --git a/R/fit.R b/R/fit.R index cc1e2e53e..7fafcd6b4 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1743,6 +1743,9 @@ inv_metric <- function(matrix = TRUE) { if (matrix && !is.matrix(out[[1]])) { # convert each vector to a diagonal matrix out <- lapply(out, diag) + } else if (length(out[[1]]) == 1) { + # convert each scalar to a 1x1 matrix + out <- lapply(out, array, dim = c(1)) } out } From daa111d9c4f664dc0de966a254b719eb26c2baac Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Tue, 19 Mar 2024 16:20:15 +0100 Subject: [PATCH 3/3] fix empty matrix inv_metric with 1 parameter --- R/args.R | 2 +- R/fit.R | 4 +-- tests/testthat/test-model-sample-metric.R | 40 ++++++++++++++++++++++- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/R/args.R b/R/args.R index 6dcf8095c..7b4a86867 100644 --- a/R/args.R +++ b/R/args.R @@ -254,7 +254,7 @@ SampleArgs <- R6::R6Class( fileext = ".json" ) for (i in seq_along(inv_metric_paths)) { - if (length(inv_metric[[i]] == 1) && metric == "diag_e") { + if (length(inv_metric[[i]]) == 1 && metric == "diag_e") { inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1)) } write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i]) diff --git a/R/fit.R b/R/fit.R index 7fafcd6b4..99feca4ce 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1742,9 +1742,9 @@ inv_metric <- function(matrix = TRUE) { out <- private$inv_metric_ if (matrix && !is.matrix(out[[1]])) { # convert each vector to a diagonal matrix - out <- lapply(out, diag) + out <- lapply(out, function(x) diag(x, nrow = length(x))) } else if (length(out[[1]]) == 1) { - # convert each scalar to a 1x1 matrix + # convert each scalar to an array with dimension 1 out <- lapply(out, array, dim = c(1)) } out diff --git a/tests/testthat/test-model-sample-metric.R b/tests/testthat/test-model-sample-metric.R index 7163b4c85..422442fac 100644 --- a/tests/testthat/test-model-sample-metric.R +++ b/tests/testthat/test-model-sample-metric.R @@ -4,6 +4,9 @@ set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") +mod2 <- testing_model("logistic") +data_list2 <- testing_data("logistic") + test_that("sample() method works with provided inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) @@ -54,7 +57,7 @@ test_that("sample() method works with provided inv_metrics", { }) -test_that("sample() method works with inv_metrics extracted from previous fit", { +test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", { expect_sample_output(fit_r <- mod$sample(data = data_list, chains = 2, seed = 123)) @@ -89,6 +92,41 @@ test_that("sample() method works with inv_metrics extracted from previous fit", seed = 123))) }) +test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", { + expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(length(inv_metric_vector[[1]]), 4) + expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4)) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + test_that("sample() method works with lists of inv_metrics", { inv_metric_vector <- array(1, dim = c(1))