Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overload matrix multiplication operator for rvars in R >= 4.3 #320

Merged
merged 2 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: posterior
Title: Tools for Working with Posterior Distributions
Version: 1.5.0
Version: 1.5.0.9000
Date: 2023-10-31
Authors@R: c(person("Paul-Christian", "Bürkner", email = "[email protected]", role = c("aut", "cre")),
person("Jonah", "Gabry", email = "[email protected]", role = c("aut")),
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# posterior (development version)

### Enhancements

* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.


# posterior 1.5.0

### Enhancements
Expand Down
20 changes: 18 additions & 2 deletions R/rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ Math.rvar_factor <- function(x, ...) {
#' is used to efficiently multiply matrices across draws, so if either `x` or `y` is an [`rvar`],
#' `x %**% y` will be much faster than `rdo(x %*% y)`.
#'
#' Because [`rvar`] is an S3 class and S3 classes cannot properly override `%*%`, [`rvar`]s use
#' `%**%` for matrix multiplication.
#' In R >= 4.3, you can also use `%*%` in place of `%**%` for matrix multiplication
#' of [`rvar`]s. In R < 4.3, S3 classes cannot properly override `%*%`, so
#' you must use `%**%` for matrix multiplication of [`rvar`]s.
#'
#' @return An [`rvar`] representing the matrix product of `x` and `y`.
#'
Expand Down Expand Up @@ -208,6 +209,21 @@ Math.rvar_factor <- function(x, ...) {
new_rvar(result, .nchains = nchains(x))
}

# This generic is not exported here as matrixOps is only in R >= 4.3, so we must
# conditionally export it in .onLoad() for compatibility with older versions
#' @rdname rvar-matmult
#' @method matrixOps rvar
matrixOps.rvar <- function(x, y) {
# as of R 4.3 this group generic is only used for %*%, but that may change
# in the future (crossprod and tcrossprod are planned), so we include this
# check for safety purposes
if (.Generic != "%*%") {
stop_no_call("Cannot apply `", .Generic, "` operator to rvar objects.")
}
x %**% y
}


#' Cholesky decomposition of random matrix
#'
#' Cholesky decomposition of an [`rvar`] containing a matrix.
Expand Down
6 changes: 6 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@
# See help("s3_register", package = "vctrs") for more information.
vctrs::s3_register("dplyr::dplyr_reconstruct", "draws_df")
vctrs::s3_register("ggplot2::scale_type", "rvar")

# S3 methods for matrixOps, which is a group generic that didn't exist
# until R 4.3, so we can't register it in NAMESPACE
if (getRversion() >= "4.3") {
registerS3method("matrixOps", "rvar", matrixOps.rvar, asNamespace("base"))
}
}
3 changes: 3 additions & 0 deletions man/draws-index.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rename_variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions man/rvar-matmult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/weight_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions tests/testthat/test-rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ test_that("matrix multiplication works", {

})

test_that("%*% works in R >= 4.3", {
skip_if_not(getRversion() >= "4.3")

x <- rvar(array(1:24, dim = c(4,2,3)))
y <- rvar(array(c(2:13,12:1), dim = c(4,3,2)))

expect_equal(x %*% y, x %**% y)
})

test_that("diag works", {
Sigma <- as_draws_rvars(example_draws("multi_normal"))$Sigma

Expand Down
10 changes: 7 additions & 3 deletions vignettes/rvar.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,12 @@ mu + 1
```

Matrix multiplication is also implemented (using a tensor product under the hood).
Because the normal matrix multiplication operator in R (`%*%`) cannot be properly
implemented for S3 datatypes, `rvar` uses `%**%` instead. A trivial example:
In R < 4.3, the normal matrix multiplication operator (`%*%`) cannot be properly
implemented for S3 datatypes, so `rvar` uses `%**%` instead. In R ≥ 4.3, which
does support matrix multiplication for S3 datatypes, you can use `%*%` to
matrix-multiply `rvar`s.

A trivial example:

```{r matrix_mult}
Sigma %**% diag(1:3)
Expand All @@ -223,7 +227,7 @@ includes:
| Logical operators | `&`, `|`, `!` |
| Comparison operators | `==`, `!=`, `<`, `<=`, `>=`, `>` |
| Value matching | `match()`, `%in%` |
| Matrix multiplication | `%**%` |
| Matrix multiplication | `%**%`, `%*%` (R ≥ 4.3 only) |
| Basic functions | `abs()`, `sign()`<br>`sqrt()`<br>`floor()`, `ceiling()`, `trunc()`, `round()`, `signif()` |
| Logarithms and exponentials | `exp()`, `expm1()`<br>`log()`, `log10()`, `log2()`, `log1p()` |
| Trigonometric functions | `cos()`, `sin()`, `tan()`<br>`cospi()`, `sinpi()`, `tanpi()`<br>`acos()`, `asin()`, `atan()`|
Expand Down
Loading