From faf695cbca6081a9ac03f3ba1b3493e0ff7152ab Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 18 Nov 2023 00:19:21 -0600 Subject: [PATCH 1/2] add support for %*% with rvars in R > 4.3 --- DESCRIPTION | 2 +- NAMESPACE | 1 + NEWS.md | 8 ++++++++ R/rvar-math.R | 23 +++++++++++++++++++---- man/draws-index.Rd | 3 +++ man/rename_variables.Rd | 2 +- man/rvar-matmult.Rd | 12 ++++++++---- man/weight_draws.Rd | 2 +- tests/testthat/test-rvar-math.R | 9 +++++++++ vignettes/rvar.Rmd | 10 +++++++--- 10 files changed, 58 insertions(+), 14 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 52a4b00d..7212372a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: posterior Title: Tools for Working with Posterior Distributions -Version: 1.5.0 +Version: 1.5.0.9000 Date: 2023-10-31 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), person("Jonah", "Gabry", email = "jsg2201@columbia.edu", role = c("aut")), diff --git a/NAMESPACE b/NAMESPACE index 2d81d65d..142224de 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -170,6 +170,7 @@ S3method(mad,rvar) S3method(mad,rvar_ordered) S3method(match,default) S3method(match,rvar) +S3method(matrixOps,rvar) S3method(max,rvar) S3method(mcse_mean,default) S3method(mcse_mean,rvar) diff --git a/NEWS.md b/NEWS.md index b5259695..7b9af2b6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# posterior (development version) + +### Enhancements + +* Matrix multiplication of `rvar`s can now be done with the base matrix + multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. + + # posterior 1.5.0 ### Enhancements diff --git a/R/rvar-math.R b/R/rvar-math.R index 37fae7d0..6676446b 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -114,14 +114,14 @@ Math.rvar_factor <- function(x, ...) { #' #' @name rvar-matmult #' @aliases %**% -#' @param x (multiple options) The object to be postmultiplied by `y`: +#' @param x,e1 (multiple options) The object to be postmultiplied by `y`: #' * An [`rvar`] #' * A [`numeric`] vector or matrix #' * A [`logical`] vector or matrix #' #' If a vector is used, it is treated as a *row* vector. #' -#' @param y (multiple options) The object to be premultiplied by `x`: +#' @param y,e2 (multiple options) The object to be premultiplied by `x`: #' * An [`rvar`] #' * A [`numeric`] vector or matrix #' * A [`logical`] vector or matrix @@ -135,8 +135,9 @@ Math.rvar_factor <- function(x, ...) { #' is used to efficiently multiply matrices across draws, so if either `x` or `y` is an [`rvar`], #' `x %**% y` will be much faster than `rdo(x %*% y)`. #' -#' Because [`rvar`] is an S3 class and S3 classes cannot properly override `%*%`, [`rvar`]s use -#' `%**%` for matrix multiplication. +#' In R >= 4.3, you can also use `%*%` in place of `%**%` for matrix multiplication +#' of [`rvar`]s. In R < 4.3, S3 classes cannot properly override `%*%`, so +#' you must use `%**%` for matrix multiplication of [`rvar`]s. #' #' @return An [`rvar`] representing the matrix product of `x` and `y`. #' @@ -208,6 +209,20 @@ Math.rvar_factor <- function(x, ...) { new_rvar(result, .nchains = nchains(x)) } +#' @rdname rvar-matmult +#' @method matrixOps rvar +#' @export +matrixOps.rvar <- function(e1, e2) { + # as of R 4.3 this group generic is only used for %*%, but that may change + # in the future (crossprod and tcrossprod are planned), so we include this + # check for safety purposes + if (.Generic != "%*%") { + stop_no_call("Cannot apply `", .Generic, "` operator to rvar objects.") + } + e1 %**% e2 +} + + #' Cholesky decomposition of random matrix #' #' Cholesky decomposition of an [`rvar`] containing a matrix. diff --git a/man/draws-index.Rd b/man/draws-index.Rd index d126e1ff..a1e21ec8 100644 --- a/man/draws-index.Rd +++ b/man/draws-index.Rd @@ -78,3 +78,6 @@ draw_ids(x) ndraws(x) } +\seealso{ +\code{\link{variables}}, \code{\link{rename_variables}} +} diff --git a/man/rename_variables.Rd b/man/rename_variables.Rd index 2f93355d..85e9df9c 100755 --- a/man/rename_variables.Rd +++ b/man/rename_variables.Rd @@ -41,5 +41,5 @@ variables(x) } \seealso{ -\code{\link{variables}}, \code{\link{mutate_variables}} +\code{\link{variables}}, \code{\link{set_variables}}, \code{\link{mutate_variables}} } diff --git a/man/rvar-matmult.Rd b/man/rvar-matmult.Rd index a5b967d4..5dcd9024 100755 --- a/man/rvar-matmult.Rd +++ b/man/rvar-matmult.Rd @@ -3,12 +3,15 @@ \name{rvar-matmult} \alias{rvar-matmult} \alias{\%**\%} +\alias{matrixOps.rvar} \title{Matrix multiplication of random variables} \usage{ x \%**\% y + +\method{matrixOps}{rvar}(e1, e2) } \arguments{ -\item{x}{(multiple options) The object to be postmultiplied by \code{y}: +\item{x, e1}{(multiple options) The object to be postmultiplied by \code{y}: \itemize{ \item An \code{\link{rvar}} \item A \code{\link{numeric}} vector or matrix @@ -17,7 +20,7 @@ x \%**\% y If a vector is used, it is treated as a \emph{row} vector.} -\item{y}{(multiple options) The object to be premultiplied by \code{x}: +\item{y, e2}{(multiple options) The object to be premultiplied by \code{x}: \itemize{ \item An \code{\link{rvar}} \item A \code{\link{numeric}} vector or matrix @@ -39,8 +42,9 @@ by \code{\link{rvar}}s and are broadcasted across all draws of the \code{\link{r is used to efficiently multiply matrices across draws, so if either \code{x} or \code{y} is an \code{\link{rvar}}, \code{x \%**\% y} will be much faster than \code{rdo(x \%*\% y)}. -Because \code{\link{rvar}} is an S3 class and S3 classes cannot properly override \code{\%*\%}, \code{\link{rvar}}s use -\verb{\%**\%} for matrix multiplication. +In R >= 4.3, you can also use \code{\%*\%} in place of \verb{\%**\%} for matrix multiplication +of \code{\link{rvar}}s. In R < 4.3, S3 classes cannot properly override \code{\%*\%}, so +you must use \verb{\%**\%} for matrix multiplication of \code{\link{rvar}}s. } \examples{ diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index ecd611c9..4601c983 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -32,7 +32,7 @@ can be returned via the \code{\link[=weights.draws]{weights.draws()}} method lat \item{...}{Arguments passed to individual methods (if applicable).} -\item{log}{(logicla) Are the weights passed already on the log scale? The +\item{log}{(logical) Are the weights passed already on the log scale? The default is \code{FALSE}, that is, expecting \code{weights} to be on the standard (non-log) scale.} } diff --git a/tests/testthat/test-rvar-math.R b/tests/testthat/test-rvar-math.R index 1088e1c7..597b7805 100755 --- a/tests/testthat/test-rvar-math.R +++ b/tests/testthat/test-rvar-math.R @@ -262,6 +262,15 @@ test_that("matrix multiplication works", { }) +test_that("%*% works in R >= 4.3", { + skip_if_not(getRversion() >= "4.3") + + x <- rvar(array(1:24, dim = c(4,2,3))) + y <- rvar(array(c(2:13,12:1), dim = c(4,3,2))) + + expect_equal(x %*% y, x %**% y) +}) + test_that("diag works", { Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma diff --git a/vignettes/rvar.Rmd b/vignettes/rvar.Rmd index 25bcf61d..fe61835a 100755 --- a/vignettes/rvar.Rmd +++ b/vignettes/rvar.Rmd @@ -207,8 +207,12 @@ mu + 1 ``` Matrix multiplication is also implemented (using a tensor product under the hood). -Because the normal matrix multiplication operator in R (`%*%`) cannot be properly -implemented for S3 datatypes, `rvar` uses `%**%` instead. A trivial example: +In R < 4.3, the normal matrix multiplication operator (`%*%`) cannot be properly +implemented for S3 datatypes, so `rvar` uses `%**%` instead. In R ≥ 4.3, which +does support matrix multiplication for S3 datatypes, you can use `%*%` to +matrix-multiply `rvar`s. + +A trivial example: ```{r matrix_mult} Sigma %**% diag(1:3) @@ -223,7 +227,7 @@ includes: | Logical operators | `&`, `|`, `!` | | Comparison operators | `==`, `!=`, `<`, `<=`, `>=`, `>` | | Value matching | `match()`, `%in%` | -| Matrix multiplication | `%**%` | +| Matrix multiplication | `%**%`, `%*%` (R ≥ 4.3 only) | | Basic functions | `abs()`, `sign()`
`sqrt()`
`floor()`, `ceiling()`, `trunc()`, `round()`, `signif()` | | Logarithms and exponentials | `exp()`, `expm1()`
`log()`, `log10()`, `log2()`, `log1p()` | | Trigonometric functions | `cos()`, `sin()`, `tan()`
`cospi()`, `sinpi()`, `tanpi()`
`acos()`, `asin()`, `atan()`| From 263a5e695625d5f3b0e352f3fe5efdceb4c28e29 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 18 Nov 2023 10:26:46 -0600 Subject: [PATCH 2/2] delayed registration for matrixOps for compatibilty with R < 4.3 --- NAMESPACE | 1 - R/rvar-math.R | 11 ++++++----- R/zzz.R | 6 ++++++ man/rvar-matmult.Rd | 6 +++--- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 142224de..2d81d65d 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -170,7 +170,6 @@ S3method(mad,rvar) S3method(mad,rvar_ordered) S3method(match,default) S3method(match,rvar) -S3method(matrixOps,rvar) S3method(max,rvar) S3method(mcse_mean,default) S3method(mcse_mean,rvar) diff --git a/R/rvar-math.R b/R/rvar-math.R index 6676446b..11a8ca3c 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -114,14 +114,14 @@ Math.rvar_factor <- function(x, ...) { #' #' @name rvar-matmult #' @aliases %**% -#' @param x,e1 (multiple options) The object to be postmultiplied by `y`: +#' @param x (multiple options) The object to be postmultiplied by `y`: #' * An [`rvar`] #' * A [`numeric`] vector or matrix #' * A [`logical`] vector or matrix #' #' If a vector is used, it is treated as a *row* vector. #' -#' @param y,e2 (multiple options) The object to be premultiplied by `x`: +#' @param y (multiple options) The object to be premultiplied by `x`: #' * An [`rvar`] #' * A [`numeric`] vector or matrix #' * A [`logical`] vector or matrix @@ -209,17 +209,18 @@ Math.rvar_factor <- function(x, ...) { new_rvar(result, .nchains = nchains(x)) } +# This generic is not exported here as matrixOps is only in R >= 4.3, so we must +# conditionally export it in .onLoad() for compatibility with older versions #' @rdname rvar-matmult #' @method matrixOps rvar -#' @export -matrixOps.rvar <- function(e1, e2) { +matrixOps.rvar <- function(x, y) { # as of R 4.3 this group generic is only used for %*%, but that may change # in the future (crossprod and tcrossprod are planned), so we include this # check for safety purposes if (.Generic != "%*%") { stop_no_call("Cannot apply `", .Generic, "` operator to rvar objects.") } - e1 %**% e2 + x %**% y } diff --git a/R/zzz.R b/R/zzz.R index b2e73811..f8744d2c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -9,4 +9,10 @@ # See help("s3_register", package = "vctrs") for more information. vctrs::s3_register("dplyr::dplyr_reconstruct", "draws_df") vctrs::s3_register("ggplot2::scale_type", "rvar") + + # S3 methods for matrixOps, which is a group generic that didn't exist + # until R 4.3, so we can't register it in NAMESPACE + if (getRversion() >= "4.3") { + registerS3method("matrixOps", "rvar", matrixOps.rvar, asNamespace("base")) + } } diff --git a/man/rvar-matmult.Rd b/man/rvar-matmult.Rd index 5dcd9024..1cab7364 100755 --- a/man/rvar-matmult.Rd +++ b/man/rvar-matmult.Rd @@ -8,10 +8,10 @@ \usage{ x \%**\% y -\method{matrixOps}{rvar}(e1, e2) +\method{matrixOps}{rvar}(x, y) } \arguments{ -\item{x, e1}{(multiple options) The object to be postmultiplied by \code{y}: +\item{x}{(multiple options) The object to be postmultiplied by \code{y}: \itemize{ \item An \code{\link{rvar}} \item A \code{\link{numeric}} vector or matrix @@ -20,7 +20,7 @@ x \%**\% y If a vector is used, it is treated as a \emph{row} vector.} -\item{y, e2}{(multiple options) The object to be premultiplied by \code{x}: +\item{y}{(multiple options) The object to be premultiplied by \code{x}: \itemize{ \item An \code{\link{rvar}} \item A \code{\link{numeric}} vector or matrix