Skip to content

Commit

Permalink
Merge pull request #320 from stan-dev/matrixOps
Browse files Browse the repository at this point in the history
Overload matrix multiplication operator for rvars in R >= 4.3
  • Loading branch information
paul-buerkner authored Nov 19, 2023
2 parents c6ab463 + 263a5e6 commit 9059111
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 10 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: posterior
Title: Tools for Working with Posterior Distributions
Version: 1.5.0
Version: 1.5.0.9000
Date: 2023-10-31
Authors@R: c(person("Paul-Christian", "Bürkner", email = "[email protected]", role = c("aut", "cre")),
person("Jonah", "Gabry", email = "[email protected]", role = c("aut")),
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# posterior (development version)

### Enhancements

* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.


# posterior 1.5.0

### Enhancements
Expand Down
20 changes: 18 additions & 2 deletions R/rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ Math.rvar_factor <- function(x, ...) {
#' is used to efficiently multiply matrices across draws, so if either `x` or `y` is an [`rvar`],
#' `x %**% y` will be much faster than `rdo(x %*% y)`.
#'
#' Because [`rvar`] is an S3 class and S3 classes cannot properly override `%*%`, [`rvar`]s use
#' `%**%` for matrix multiplication.
#' In R >= 4.3, you can also use `%*%` in place of `%**%` for matrix multiplication
#' of [`rvar`]s. In R < 4.3, S3 classes cannot properly override `%*%`, so
#' you must use `%**%` for matrix multiplication of [`rvar`]s.
#'
#' @return An [`rvar`] representing the matrix product of `x` and `y`.
#'
Expand Down Expand Up @@ -208,6 +209,21 @@ Math.rvar_factor <- function(x, ...) {
new_rvar(result, .nchains = nchains(x))
}

# This generic is not exported here as matrixOps is only in R >= 4.3, so we must
# conditionally export it in .onLoad() for compatibility with older versions
#' @rdname rvar-matmult
#' @method matrixOps rvar
matrixOps.rvar <- function(x, y) {
# as of R 4.3 this group generic is only used for %*%, but that may change
# in the future (crossprod and tcrossprod are planned), so we include this
# check for safety purposes
if (.Generic != "%*%") {
stop_no_call("Cannot apply `", .Generic, "` operator to rvar objects.")
}
x %**% y
}


#' Cholesky decomposition of random matrix
#'
#' Cholesky decomposition of an [`rvar`] containing a matrix.
Expand Down
6 changes: 6 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@
# See help("s3_register", package = "vctrs") for more information.
vctrs::s3_register("dplyr::dplyr_reconstruct", "draws_df")
vctrs::s3_register("ggplot2::scale_type", "rvar")

# S3 methods for matrixOps, which is a group generic that didn't exist
# until R 4.3, so we can't register it in NAMESPACE
if (getRversion() >= "4.3") {
registerS3method("matrixOps", "rvar", matrixOps.rvar, asNamespace("base"))
}
}
3 changes: 3 additions & 0 deletions man/draws-index.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rename_variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions man/rvar-matmult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/weight_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions tests/testthat/test-rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ test_that("matrix multiplication works", {

})

test_that("%*% works in R >= 4.3", {
skip_if_not(getRversion() >= "4.3")

x <- rvar(array(1:24, dim = c(4,2,3)))
y <- rvar(array(c(2:13,12:1), dim = c(4,3,2)))

expect_equal(x %*% y, x %**% y)
})

test_that("diag works", {
Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma

Expand Down
10 changes: 7 additions & 3 deletions vignettes/rvar.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,12 @@ mu + 1
```

Matrix multiplication is also implemented (using a tensor product under the hood).
Because the normal matrix multiplication operator in R (`%*%`) cannot be properly
implemented for S3 datatypes, `rvar` uses `%**%` instead. A trivial example:
In R < 4.3, the normal matrix multiplication operator (`%*%`) cannot be properly
implemented for S3 datatypes, so `rvar` uses `%**%` instead. In R ≥ 4.3, which
does support matrix multiplication for S3 datatypes, you can use `%*%` to
matrix-multiply `rvar`s.

A trivial example:

```{r matrix_mult}
Sigma %**% diag(1:3)
Expand All @@ -223,7 +227,7 @@ includes:
| Logical operators | `&`, `|`, `!` |
| Comparison operators | `==`, `!=`, `<`, `<=`, `>=`, `>` |
| Value matching | `match()`, `%in%` |
| Matrix multiplication | `%**%` |
| Matrix multiplication | `%**%`, `%*%` (R ≥ 4.3 only) |
| Basic functions | `abs()`, `sign()`<br>`sqrt()`<br>`floor()`, `ceiling()`, `trunc()`, `round()`, `signif()` |
| Logarithms and exponentials | `exp()`, `expm1()`<br>`log()`, `log10()`, `log2()`, `log1p()` |
| Trigonometric functions | `cos()`, `sin()`, `tan()`<br>`cospi()`, `sinpi()`, `tanpi()`<br>`acos()`, `asin()`, `atan()`|
Expand Down

0 comments on commit 9059111

Please sign in to comment.