Skip to content

Commit

Permalink
Scalar multiplication (#355)
Browse files Browse the repository at this point in the history
* Added tests for scalar multiplication for FactoredMatrix

* Added  __mul__ and __rmul__ to FactoredMatrix

* Tests for errors when multiplying by non-scalar

* Added scalar.shape to error message

* Fixed imports to make isort happy

* Black Formatting

* Changed to random.random and randint

* Implementation dependent test for factored matrix A.
  • Loading branch information
matthiasdellago authored Jul 26, 2023
1 parent 5b26456 commit 090081f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/unit/factored_matrix/test_multiply_by_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import random

import pytest
import torch
from torch.testing import assert_close

from transformer_lens import FactoredMatrix


# This test function is parametrized with different types of scalars, including non-scalar tensors and arrays, to check that the correct errors are raised.
# Considers cases with and without leading dimensions as well as left and right multiplication.
@pytest.mark.parametrize(
"scalar, error_expected",
[
# Test cases with different types of scalar values.
(torch.rand(1), None), # 1-element Tensor. No error expected.
(random.random(), None), # float. No error expected.
(random.randint(-100, 100), None), # int. No error expected.
# Test cases with non-scalar values that are expected to raise errors.
(
torch.rand(2, 2),
AssertionError,
), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
],
)
@pytest.mark.parametrize("leading_dim", [False, True])
@pytest.mark.parametrize("multiply_from_left", [False, True])
def test_multiply(scalar, leading_dim, multiply_from_left, error_expected):
# Prepare a FactoredMatrix, with or without leading dimensions
if leading_dim:
a = torch.rand(6, 2, 3)
b = torch.rand(6, 3, 4)
else:
a = torch.rand(2, 3)
b = torch.rand(3, 4)

fm = FactoredMatrix(a, b)

if error_expected:
# If an error is expected, check that the correct exception is raised.
with pytest.raises(error_expected):
if multiply_from_left:
_ = fm * scalar
else:
_ = scalar * fm
else:
# If no error is expected, check that the multiplication results in the correct value.
# Use FactoredMatrix.AB to calculate the product of the two factor matrices before comparing with the expected value.
if multiply_from_left:
assert_close((fm * scalar).AB, (a @ b) * scalar)
else:
assert_close((scalar * fm).AB, scalar * (a @ b))
# This next test is implementation dependant and can be broken and removed at any time!
# It checks that the multiplication is performed on the A factor matrix.
if multiply_from_left:
assert_close((fm * scalar).A, a * scalar)
else:
assert_close((scalar * fm).A, scalar * a)
16 changes: 16 additions & 0 deletions transformer_lens/FactoredMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ def __rmatmul__(
elif isinstance(other, FactoredMatrix):
return other.A @ (other.B @ self)

def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
"""
Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.
"""
if isinstance(scalar, torch.Tensor):
assert (
scalar.numel() == 1
), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
return FactoredMatrix(self.A * scalar, self.B)

def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
"""
Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
"""
return self * scalar

@property
@typeguard_ignore
def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
Expand Down

0 comments on commit 090081f

Please sign in to comment.