Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into feature/pathfinder-…
Browse files Browse the repository at this point in the history
…inits
  • Loading branch information
SteveBronder committed Mar 20, 2024
2 parents b1f9fdc + ae1b7b3 commit a078b18
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 11 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check-wsl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ jobs:

- uses: actions/checkout@v4

- uses: r-lib/actions/setup-r@v2.7.2
- uses: r-lib/actions/setup-r@v2.8.2
with:
r-version: 'release'
rtools-version: '42'
- uses: r-lib/actions/setup-pandoc@v2.7.2
- uses: r-lib/actions/setup-pandoc@v2.8.2

- name: Query dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ jobs:
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
- uses: r-lib/actions/setup-r@v2.7.2
- uses: r-lib/actions/setup-r@v2.8.2
with:
r-version: ${{ matrix.config.r }}
rtools-version: ${{ matrix.config.rtools }}
- uses: r-lib/actions/setup-pandoc@v2.7.2
- uses: r-lib/actions/setup-pandoc@v2.8.2

- name: Query dependencies
run: |
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/Test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ jobs:
if: "!startsWith(github.ref, 'refs/tags/') && github.ref != 'refs/heads/master'"
- uses: actions/checkout@v4

- uses: r-lib/actions/setup-r@v2.7.2
- uses: r-lib/actions/setup-pandoc@v2.7.2
- uses: r-lib/actions/setup-r@v2.8.2
- uses: r-lib/actions/setup-pandoc@v2.8.2

- name: Install Ubuntu dependencies
run: |
Expand Down Expand Up @@ -85,12 +85,12 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: r-lib/actions/setup-r@v2.7.2
- uses: r-lib/actions/setup-r@v2.8.2
with:
r-version: 'release'
rtools-version: '42'

- uses: r-lib/actions/setup-pandoc@v2.7.2
- uses: r-lib/actions/setup-pandoc@v2.8.2

- name: Query dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/cmdstan-tarball-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ jobs:
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
- uses: r-lib/actions/setup-r@v2.7.2
- uses: r-lib/actions/setup-r@v2.8.2
with:
r-version: ${{ matrix.config.r }}
rtools-version: ${{ matrix.config.rtools }}

- uses: r-lib/actions/setup-pandoc@v2.7.2
- uses: r-lib/actions/setup-pandoc@v2.8.2

- name: Query dependencies
run: |
Expand Down
3 changes: 3 additions & 0 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ SampleArgs <- R6::R6Class(
fileext = ".json"
)
for (i in seq_along(inv_metric_paths)) {
if (length(inv_metric[[i]]) == 1 && metric == "diag_e") {
inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1))
}
write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i])
}

Expand Down
5 changes: 4 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,10 @@ inv_metric <- function(matrix = TRUE) {
out <- private$inv_metric_
if (matrix && !is.matrix(out[[1]])) {
# convert each vector to a diagonal matrix
out <- lapply(out, diag)
out <- lapply(out, function(x) diag(x, nrow = length(x)))
} else if (length(out[[1]]) == 1) {
# convert each scalar to an array with dimension 1
out <- lapply(out, array, dim = c(1))
}
out
}
Expand Down
74 changes: 74 additions & 0 deletions tests/testthat/test-model-sample-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ set_cmdstan_path()
mod <- testing_model("bernoulli")
data_list <- testing_data("bernoulli")

mod2 <- testing_model("logistic")
data_list2 <- testing_data("logistic")


test_that("sample() method works with provided inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
Expand Down Expand Up @@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", {
})


test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", {
expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(dim(inv_metric_vector[[1]]), 1)
expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})

test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", {
expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(length(inv_metric_vector[[1]]), 4)
expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})


test_that("sample() method works with lists of inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")
Expand Down

0 comments on commit a078b18

Please sign in to comment.