Skip to content

Commit

Permalink
Merge branch 'master' into expose_new_stan_args
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Apr 19, 2024
2 parents 695dba6 + 30c945c commit 6daf429
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 21 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/R-CMD-check-wsl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@ jobs:

- uses: actions/checkout@v4

- uses: r-lib/actions/[email protected]
with:
r-version: 'release'
rtools-version: '42'
- uses: r-lib/actions/[email protected]
- uses: r-lib/actions/[email protected]
- uses: r-lib/actions/[email protected]

- name: Query dependencies
run: |
Expand All @@ -58,7 +55,7 @@ jobs:
install.packages("curl")
shell: Rscript {0}

- uses: Vampire/setup-wsl@v2
- uses: Vampire/setup-wsl@v3
with:
distribution: Ubuntu-22.04
use-cache: 'false'
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ jobs:
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7
with:
r-version: ${{ matrix.config.r }}
rtools-version: ${{ matrix.config.rtools }}
- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7

- name: Query dependencies
run: |
Expand Down
11 changes: 4 additions & 7 deletions .github/workflows/Test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ jobs:
if: "!startsWith(github.ref, 'refs/tags/') && github.ref != 'refs/heads/master'"
- uses: actions/checkout@v4

- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7
- uses: r-lib/actions/[email protected].7

- name: Install Ubuntu dependencies
run: |
Expand Down Expand Up @@ -85,12 +85,9 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: r-lib/actions/[email protected]
with:
r-version: 'release'
rtools-version: '42'
- uses: r-lib/actions/[email protected]

- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7

- name: Query dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/cmdstan-tarball-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ jobs:
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7
with:
r-version: ${{ matrix.config.r }}
rtools-version: ${{ matrix.config.rtools }}

- uses: r-lib/actions/[email protected].2
- uses: r-lib/actions/[email protected].7

- name: Query dependencies
run: |
Expand Down
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanPathfinder)
S3method(as_draws,CmdStanVB)
export(as.CmdStanDiagnose)
export(as.CmdStanGQ)
export(as.CmdStanLaplace)
export(as.CmdStanMCMC)
export(as.CmdStanMLE)
export(as.CmdStanPathfinder)
export(as.CmdStanVB)
export(as_cmdstan_fit)
export(as_draws)
export(as_mcmc.list)
Expand Down
3 changes: 3 additions & 0 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ SampleArgs <- R6::R6Class(
fileext = ".json"
)
for (i in seq_along(inv_metric_paths)) {
if (length(inv_metric[[i]]) == 1 && metric == "diag_e") {
inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1))
}
write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i])
}

Expand Down
5 changes: 4 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,10 @@ inv_metric <- function(matrix = TRUE) {
out <- private$inv_metric_
if (matrix && !is.matrix(out[[1]])) {
# convert each vector to a diagonal matrix
out <- lapply(out, diag)
out <- lapply(out, function(x) diag(x, nrow = length(x)))
} else if (length(out[[1]]) == 1) {
# convert each scalar to an array with dimension 1
out <- lapply(out, array, dim = c(1))
}
out
}
Expand Down
54 changes: 54 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@

#' Coercion methods for CmdStan objects
#'
#' These methods are used to coerce objects into `cmdstanr` objects.
#' Primarily intended for other packages to use when interfacing
#' with `cmdstanr`.
#'
#' @param object to be coerced
#' @param ... additional arguments
#'
#' @name cmdstan_coercion
NULL

#' @rdname cmdstan_coercion
#' @export
as.CmdStanMCMC <- function(object, ...) {
UseMethod("as.CmdStanMCMC")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanMLE <- function(object, ...) {
UseMethod("as.CmdStanMLE")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanLaplace <- function(object, ...) {
UseMethod("as.CmdStanLaplace")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanVB <- function(object, ...) {
UseMethod("as.CmdStanVB")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanPathfinder <- function(object, ...) {
UseMethod("as.CmdStanPathfinder")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanGQ <- function(object, ...) {
UseMethod("as.CmdStanGQ")
}

#' @rdname cmdstan_coercion
#' @export
as.CmdStanDiagnose <- function(object, ...) {
UseMethod("as.CmdStanDiagnose")
}
4 changes: 3 additions & 1 deletion R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,10 @@ rtools4x_version <- function() {
rtools_ver <- "40"
} else if (R.version$minor < "3.0") {
rtools_ver <- "42"
} else {
} else if (R.version$minor < "5.0") {
rtools_ver <- "43"
} else {
rtools_ver <- "44"
}
rtools_ver
}
Expand Down
5 changes: 3 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ assert_valid_draws_format <- function(format) {
}
if (format %in% c("rvars", "draws_rvars")) {
stop(
"\nWe are fixing a bug in fit$draws(format = 'draws_rvars').",
"\nFor now please use posterior::as_draws_rvars(fit$draws()) instead."
"\nTo use the rvar format please convert after extracting the draws, ",
"e.g., posterior::as_draws_rvars(fit$draws()).",
call. = FALSE
)
}
}
Expand Down
37 changes: 37 additions & 0 deletions man/cmdstan_coercion.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

74 changes: 74 additions & 0 deletions tests/testthat/test-model-sample-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ set_cmdstan_path()
mod <- testing_model("bernoulli")
data_list <- testing_data("bernoulli")

mod2 <- testing_model("logistic")
data_list2 <- testing_data("logistic")


test_that("sample() method works with provided inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
Expand Down Expand Up @@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", {
})


test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", {
expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(dim(inv_metric_vector[[1]]), 1)
expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})

test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", {
expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
seed = 123))
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
inv_metric_matrix <- fit_r$inv_metric()

expect_equal(length(inv_metric_vector[[1]]), 4)
expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "diag_e",
inv_metric = inv_metric_vector[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 1,
metric = "dense_e",
inv_metric = inv_metric_matrix[[1]],
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "diag_e",
inv_metric = inv_metric_vector,
seed = 123)))

expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
chains = 2,
metric = "dense_e",
inv_metric = inv_metric_matrix,
seed = 123)))
})


test_that("sample() method works with lists of inv_metrics", {
inv_metric_vector <- array(1, dim = c(1))
inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")
Expand Down

0 comments on commit 6daf429

Please sign in to comment.