diff --git a/.github/workflows/R-CMD-check-wsl.yaml b/.github/workflows/R-CMD-check-wsl.yaml index b24a728c..a62e58b8 100644 --- a/.github/workflows/R-CMD-check-wsl.yaml +++ b/.github/workflows/R-CMD-check-wsl.yaml @@ -37,11 +37,8 @@ jobs: - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2.8.2 - with: - r-version: 'release' - rtools-version: '42' - - uses: r-lib/actions/setup-pandoc@v2.8.2 + - uses: r-lib/actions/setup-r@v2.8.7 + - uses: r-lib/actions/setup-pandoc@v2.8.7 - name: Query dependencies run: | @@ -58,7 +55,7 @@ jobs: install.packages("curl") shell: Rscript {0} - - uses: Vampire/setup-wsl@v2 + - uses: Vampire/setup-wsl@v3 with: distribution: Ubuntu-22.04 use-cache: 'false' diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 1ff60668..9830ff5d 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -57,11 +57,11 @@ jobs: sudo apt-get install -y libcurl4-openssl-dev || true sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true - - uses: r-lib/actions/setup-r@v2.8.2 + - uses: r-lib/actions/setup-r@v2.8.7 with: r-version: ${{ matrix.config.r }} rtools-version: ${{ matrix.config.rtools }} - - uses: r-lib/actions/setup-pandoc@v2.8.2 + - uses: r-lib/actions/setup-pandoc@v2.8.7 - name: Query dependencies run: | diff --git a/.github/workflows/Test-coverage.yaml b/.github/workflows/Test-coverage.yaml index 69c19fdc..311020ef 100644 --- a/.github/workflows/Test-coverage.yaml +++ b/.github/workflows/Test-coverage.yaml @@ -34,8 +34,8 @@ jobs: if: "!startsWith(github.ref, 'refs/tags/') && github.ref != 'refs/heads/master'" - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2.8.2 - - uses: r-lib/actions/setup-pandoc@v2.8.2 + - uses: r-lib/actions/setup-r@v2.8.7 + - uses: r-lib/actions/setup-pandoc@v2.8.7 - name: Install Ubuntu dependencies run: | @@ -85,12 +85,9 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2.8.2 - with: - r-version: 'release' - rtools-version: '42' + - uses: r-lib/actions/setup-r@v2.8.7 - - uses: r-lib/actions/setup-pandoc@v2.8.2 + - uses: r-lib/actions/setup-pandoc@v2.8.7 - name: Query dependencies run: | diff --git a/.github/workflows/cmdstan-tarball-check.yaml b/.github/workflows/cmdstan-tarball-check.yaml index fd995436..d9869f78 100644 --- a/.github/workflows/cmdstan-tarball-check.yaml +++ b/.github/workflows/cmdstan-tarball-check.yaml @@ -40,12 +40,12 @@ jobs: sudo apt-get install -y libcurl4-openssl-dev || true sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true - - uses: r-lib/actions/setup-r@v2.8.2 + - uses: r-lib/actions/setup-r@v2.8.7 with: r-version: ${{ matrix.config.r }} rtools-version: ${{ matrix.config.rtools }} - - uses: r-lib/actions/setup-pandoc@v2.8.2 + - uses: r-lib/actions/setup-pandoc@v2.8.7 - name: Query dependencies run: | diff --git a/NAMESPACE b/NAMESPACE index c8a6217d..559452fb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,13 @@ S3method(as_draws,CmdStanMCMC) S3method(as_draws,CmdStanMLE) S3method(as_draws,CmdStanPathfinder) S3method(as_draws,CmdStanVB) +export(as.CmdStanDiagnose) +export(as.CmdStanGQ) +export(as.CmdStanLaplace) +export(as.CmdStanMCMC) +export(as.CmdStanMLE) +export(as.CmdStanPathfinder) +export(as.CmdStanVB) export(as_cmdstan_fit) export(as_draws) export(as_mcmc.list) diff --git a/R/args.R b/R/args.R index 01378ddf..2a317f0d 100644 --- a/R/args.R +++ b/R/args.R @@ -266,6 +266,9 @@ SampleArgs <- R6::R6Class( fileext = ".json" ) for (i in seq_along(inv_metric_paths)) { + if (length(inv_metric[[i]]) == 1 && metric == "diag_e") { + inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1)) + } write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i]) } diff --git a/R/fit.R b/R/fit.R index 8fc12402..ea527009 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1784,7 +1784,10 @@ inv_metric <- function(matrix = TRUE) { out <- private$inv_metric_ if (matrix && !is.matrix(out[[1]])) { # convert each vector to a diagonal matrix - out <- lapply(out, diag) + out <- lapply(out, function(x) diag(x, nrow = length(x))) + } else if (length(out[[1]]) == 1) { + # convert each scalar to an array with dimension 1 + out <- lapply(out, array, dim = c(1)) } out } diff --git a/R/generics.R b/R/generics.R new file mode 100644 index 00000000..663cb859 --- /dev/null +++ b/R/generics.R @@ -0,0 +1,54 @@ + +#' Coercion methods for CmdStan objects +#' +#' These methods are used to coerce objects into `cmdstanr` objects. +#' Primarily intended for other packages to use when interfacing +#' with `cmdstanr`. +#' +#' @param object to be coerced +#' @param ... additional arguments +#' +#' @name cmdstan_coercion +NULL + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanMCMC <- function(object, ...) { + UseMethod("as.CmdStanMCMC") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanMLE <- function(object, ...) { + UseMethod("as.CmdStanMLE") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanLaplace <- function(object, ...) { + UseMethod("as.CmdStanLaplace") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanVB <- function(object, ...) { + UseMethod("as.CmdStanVB") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanPathfinder <- function(object, ...) { + UseMethod("as.CmdStanPathfinder") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanGQ <- function(object, ...) { + UseMethod("as.CmdStanGQ") +} + +#' @rdname cmdstan_coercion +#' @export +as.CmdStanDiagnose <- function(object, ...) { + UseMethod("as.CmdStanDiagnose") +} diff --git a/R/install.R b/R/install.R index 7d92f5e9..995c14a9 100644 --- a/R/install.R +++ b/R/install.R @@ -855,8 +855,10 @@ rtools4x_version <- function() { rtools_ver <- "40" } else if (R.version$minor < "3.0") { rtools_ver <- "42" - } else { + } else if (R.version$minor < "5.0") { rtools_ver <- "43" + } else { + rtools_ver <- "44" } rtools_ver } diff --git a/R/utils.R b/R/utils.R index 2420fd69..05d04731 100644 --- a/R/utils.R +++ b/R/utils.R @@ -394,8 +394,9 @@ assert_valid_draws_format <- function(format) { } if (format %in% c("rvars", "draws_rvars")) { stop( - "\nWe are fixing a bug in fit$draws(format = 'draws_rvars').", - "\nFor now please use posterior::as_draws_rvars(fit$draws()) instead." + "\nTo use the rvar format please convert after extracting the draws, ", + "e.g., posterior::as_draws_rvars(fit$draws()).", + call. = FALSE ) } } diff --git a/man/cmdstan_coercion.Rd b/man/cmdstan_coercion.Rd new file mode 100644 index 00000000..d2d2b483 --- /dev/null +++ b/man/cmdstan_coercion.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R +\name{cmdstan_coercion} +\alias{cmdstan_coercion} +\alias{as.CmdStanMCMC} +\alias{as.CmdStanMLE} +\alias{as.CmdStanLaplace} +\alias{as.CmdStanVB} +\alias{as.CmdStanPathfinder} +\alias{as.CmdStanGQ} +\alias{as.CmdStanDiagnose} +\title{Coercion methods for CmdStan objects} +\usage{ +as.CmdStanMCMC(object, ...) + +as.CmdStanMLE(object, ...) + +as.CmdStanLaplace(object, ...) + +as.CmdStanVB(object, ...) + +as.CmdStanPathfinder(object, ...) + +as.CmdStanGQ(object, ...) + +as.CmdStanDiagnose(object, ...) +} +\arguments{ +\item{object}{to be coerced} + +\item{...}{additional arguments} +} +\description{ +These methods are used to coerce objects into \code{cmdstanr} objects. +Primarily intended for other packages to use when interfacing +with \code{cmdstanr}. +} diff --git a/tests/testthat/test-model-sample-metric.R b/tests/testthat/test-model-sample-metric.R index 38159d40..422442fa 100644 --- a/tests/testthat/test-model-sample-metric.R +++ b/tests/testthat/test-model-sample-metric.R @@ -4,6 +4,9 @@ set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") +mod2 <- testing_model("logistic") +data_list2 <- testing_data("logistic") + test_that("sample() method works with provided inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) @@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", { }) +test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", { + expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(dim(inv_metric_vector[[1]]), 1) + expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1)) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + +test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", { + expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + seed = 123)) + inv_metric_vector <- fit_r$inv_metric(matrix = FALSE) + inv_metric_matrix <- fit_r$inv_metric() + + expect_equal(length(inv_metric_vector[[1]]), 4) + expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4)) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "diag_e", + inv_metric = inv_metric_vector[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 1, + metric = "dense_e", + inv_metric = inv_metric_matrix[[1]], + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "diag_e", + inv_metric = inv_metric_vector, + seed = 123))) + + expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2, + chains = 2, + metric = "dense_e", + inv_metric = inv_metric_matrix, + seed = 123))) +}) + + test_that("sample() method works with lists of inv_metrics", { inv_metric_vector <- array(1, dim = c(1)) inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")