diff --git a/DESCRIPTION b/DESCRIPTION index 52a4b00d..7212372a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: posterior Title: Tools for Working with Posterior Distributions -Version: 1.5.0 +Version: 1.5.0.9000 Date: 2023-10-31 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), person("Jonah", "Gabry", email = "jsg2201@columbia.edu", role = c("aut")), diff --git a/NEWS.md b/NEWS.md index b5259695..7b9af2b6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# posterior (development version) + +### Enhancements + +* Matrix multiplication of `rvar`s can now be done with the base matrix + multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. + + # posterior 1.5.0 ### Enhancements diff --git a/R/rvar-math.R b/R/rvar-math.R index 37fae7d0..11a8ca3c 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -135,8 +135,9 @@ Math.rvar_factor <- function(x, ...) { #' is used to efficiently multiply matrices across draws, so if either `x` or `y` is an [`rvar`], #' `x %**% y` will be much faster than `rdo(x %*% y)`. #' -#' Because [`rvar`] is an S3 class and S3 classes cannot properly override `%*%`, [`rvar`]s use -#' `%**%` for matrix multiplication. +#' In R >= 4.3, you can also use `%*%` in place of `%**%` for matrix multiplication +#' of [`rvar`]s. In R < 4.3, S3 classes cannot properly override `%*%`, so +#' you must use `%**%` for matrix multiplication of [`rvar`]s. #' #' @return An [`rvar`] representing the matrix product of `x` and `y`. #' @@ -208,6 +209,21 @@ Math.rvar_factor <- function(x, ...) { new_rvar(result, .nchains = nchains(x)) } +# This generic is not exported here as matrixOps is only in R >= 4.3, so we must +# conditionally export it in .onLoad() for compatibility with older versions +#' @rdname rvar-matmult +#' @method matrixOps rvar +matrixOps.rvar <- function(x, y) { + # as of R 4.3 this group generic is only used for %*%, but that may change + # in the future (crossprod and tcrossprod are planned), so we include this + # check for safety purposes + if (.Generic != "%*%") { + stop_no_call("Cannot apply `", .Generic, "` operator to rvar objects.") + } + x %**% y +} + + #' Cholesky decomposition of random matrix #' #' Cholesky decomposition of an [`rvar`] containing a matrix. diff --git a/R/zzz.R b/R/zzz.R index b2e73811..f8744d2c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -9,4 +9,10 @@ # See help("s3_register", package = "vctrs") for more information. vctrs::s3_register("dplyr::dplyr_reconstruct", "draws_df") vctrs::s3_register("ggplot2::scale_type", "rvar") + + # S3 methods for matrixOps, which is a group generic that didn't exist + # until R 4.3, so we can't register it in NAMESPACE + if (getRversion() >= "4.3") { + registerS3method("matrixOps", "rvar", matrixOps.rvar, asNamespace("base")) + } } diff --git a/man/draws-index.Rd b/man/draws-index.Rd index d126e1ff..a1e21ec8 100644 --- a/man/draws-index.Rd +++ b/man/draws-index.Rd @@ -78,3 +78,6 @@ draw_ids(x) ndraws(x) } +\seealso{ +\code{\link{variables}}, \code{\link{rename_variables}} +} diff --git a/man/rename_variables.Rd b/man/rename_variables.Rd index 2f93355d..85e9df9c 100755 --- a/man/rename_variables.Rd +++ b/man/rename_variables.Rd @@ -41,5 +41,5 @@ variables(x) } \seealso{ -\code{\link{variables}}, \code{\link{mutate_variables}} +\code{\link{variables}}, \code{\link{set_variables}}, \code{\link{mutate_variables}} } diff --git a/man/rvar-matmult.Rd b/man/rvar-matmult.Rd index a5b967d4..1cab7364 100755 --- a/man/rvar-matmult.Rd +++ b/man/rvar-matmult.Rd @@ -3,9 +3,12 @@ \name{rvar-matmult} \alias{rvar-matmult} \alias{\%**\%} +\alias{matrixOps.rvar} \title{Matrix multiplication of random variables} \usage{ x \%**\% y + +\method{matrixOps}{rvar}(x, y) } \arguments{ \item{x}{(multiple options) The object to be postmultiplied by \code{y}: @@ -39,8 +42,9 @@ by \code{\link{rvar}}s and are broadcasted across all draws of the \code{\link{r is used to efficiently multiply matrices across draws, so if either \code{x} or \code{y} is an \code{\link{rvar}}, \code{x \%**\% y} will be much faster than \code{rdo(x \%*\% y)}. -Because \code{\link{rvar}} is an S3 class and S3 classes cannot properly override \code{\%*\%}, \code{\link{rvar}}s use -\verb{\%**\%} for matrix multiplication. +In R >= 4.3, you can also use \code{\%*\%} in place of \verb{\%**\%} for matrix multiplication +of \code{\link{rvar}}s. In R < 4.3, S3 classes cannot properly override \code{\%*\%}, so +you must use \verb{\%**\%} for matrix multiplication of \code{\link{rvar}}s. } \examples{ diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index ecd611c9..4601c983 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -32,7 +32,7 @@ can be returned via the \code{\link[=weights.draws]{weights.draws()}} method lat \item{...}{Arguments passed to individual methods (if applicable).} -\item{log}{(logicla) Are the weights passed already on the log scale? The +\item{log}{(logical) Are the weights passed already on the log scale? The default is \code{FALSE}, that is, expecting \code{weights} to be on the standard (non-log) scale.} } diff --git a/tests/testthat/test-rvar-math.R b/tests/testthat/test-rvar-math.R index 1088e1c7..597b7805 100755 --- a/tests/testthat/test-rvar-math.R +++ b/tests/testthat/test-rvar-math.R @@ -262,6 +262,15 @@ test_that("matrix multiplication works", { }) +test_that("%*% works in R >= 4.3", { + skip_if_not(getRversion() >= "4.3") + + x <- rvar(array(1:24, dim = c(4,2,3))) + y <- rvar(array(c(2:13,12:1), dim = c(4,3,2))) + + expect_equal(x %*% y, x %**% y) +}) + test_that("diag works", { Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma diff --git a/vignettes/rvar.Rmd b/vignettes/rvar.Rmd index 25bcf61d..fe61835a 100755 --- a/vignettes/rvar.Rmd +++ b/vignettes/rvar.Rmd @@ -207,8 +207,12 @@ mu + 1 ``` Matrix multiplication is also implemented (using a tensor product under the hood). -Because the normal matrix multiplication operator in R (`%*%`) cannot be properly -implemented for S3 datatypes, `rvar` uses `%**%` instead. A trivial example: +In R < 4.3, the normal matrix multiplication operator (`%*%`) cannot be properly +implemented for S3 datatypes, so `rvar` uses `%**%` instead. In R ≥ 4.3, which +does support matrix multiplication for S3 datatypes, you can use `%*%` to +matrix-multiply `rvar`s. + +A trivial example: ```{r matrix_mult} Sigma %**% diag(1:3) @@ -223,7 +227,7 @@ includes: | Logical operators | `&`, `|`, `!` | | Comparison operators | `==`, `!=`, `<`, `<=`, `>=`, `>` | | Value matching | `match()`, `%in%` | -| Matrix multiplication | `%**%` | +| Matrix multiplication | `%**%`, `%*%` (R ≥ 4.3 only) | | Basic functions | `abs()`, `sign()`
`sqrt()`
`floor()`, `ceiling()`, `trunc()`, `round()`, `signif()` | | Logarithms and exponentials | `exp()`, `expm1()`
`log()`, `log10()`, `log2()`, `log1p()` | | Trigonometric functions | `cos()`, `sin()`, `tan()`
`cospi()`, `sinpi()`, `tanpi()`
`acos()`, `asin()`, `atan()`|