Skip to content

Commit

Permalink
Merge pull request #935 from venpopov/inv_metric_1par
Browse files Browse the repository at this point in the history
Fix incorrect format of inv_metric when only 1 parameter in model
  • Loading branch information
jgabry authored Mar 19, 2024
2 parents 82f9d9a + c69ef19 commit ae1b7b3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
3 changes: 3 additions & 0 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ SampleArgs <- R6::R6Class(
fileext = ".json"
)
for (i in seq_along(inv_metric_paths)) {
if (length(inv_metric[[i]]) == 1 && metric == "diag_e") {
inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1))
}
write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i])
}

Expand Down
5 changes: 4 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,10 @@ inv_metric <- function(matrix = TRUE) {
out <- private$inv_metric_
if (matrix && !is.matrix(out[[1]])) {
# convert each vector to a diagonal matrix
out <- lapply(out, diag)
out <- lapply(out, function(x) diag(x, nrow = length(x)))
} else if (length(out[[1]]) == 1) {
# convert each scalar to an array with dimension 1
out <- lapply(out, array, dim = c(1))
}
out
}
Expand Down
74 changes: 74 additions & 0 deletions tests/testthat/test-model-sample-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ set_cmdstan_path()
mod <- testing_model("bernoulli")
data_list <- testing_data("bernoulli")

mod2 <- testing_model("logistic")
data_list2 <- testing_data("logistic")


test_that("sample() method works with provided inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
Expand Down Expand Up @@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", {
})


test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", {
expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(dim(inv_metric_vector[[1]]), 1)
expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})

test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", {
expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(length(inv_metric_vector[[1]]), 4)
expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})


test_that("sample() method works with lists of inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")
Expand Down

0 comments on commit ae1b7b3

Please sign in to comment.