Switch to beartype (#325)
dkamm authored Aug 3, 2023
1 parent 0d2827e commit 10d2f8a
Showing 12 changed files with 54 additions and 92 deletions.
4 changes: 2 additions & 2 deletions makefile
@@ -13,7 +13,7 @@ test:
 	make acceptance-test
 
 unit-test:
-	poetry run pytest -v --typeguard-packages=transformer_lens --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/unit
+	poetry run pytest -v --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/unit
 
 acceptance-test:
-	poetry run pytest -v --typeguard-packages=transformer_lens --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/acceptance
+	poetry run pytest -v --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/acceptance
29 changes: 24 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ sphinx-autobuild = {version = ">=2021.3.14", optional = true, python = ">=3.8,<3
 furo = {version = ">=2022.12.7", optional = true, python = ">=3.8,<3.10"}
 myst-parser = {version = ">=0.18.1", optional = true, python = ">=3.8,<3.10"}
 tabulate= {version = ">=0.9.0", optional = true, python = ">=3.8,<3.10"}
-typeguard = "^3.0.2"
+beartype = "^0.14.1"
 
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.2.0"
@@ -57,6 +57,7 @@ filterwarnings = [
     # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
     "ignore:distutils Version classes are deprecated:DeprecationWarning"
 ]
+addopts = "--jaxtyping-packages=transformer_lens,beartype.beartype"
 
 [tool.isort]
 profile = "black"
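The new addopts line turns on jaxtyping's pytest plugin, which instruments the transformer_lens package with beartype during test runs so that shape-annotated signatures are checked at call time. As a rough illustration (not code from this commit; the function name and shapes below are hypothetical), the same kind of check can be applied by hand by stacking the jaxtyped and beartype decorators, the usage jaxtyping documented around the time of this commit (newer jaxtyping versions prefer jaxtyped(typechecker=beartype)):

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped
@beartype
def scale_rows(
    x: Float[torch.Tensor, "batch d_model"],
    scale: Float[torch.Tensor, "batch"],
) -> Float[torch.Tensor, "batch d_model"]:
    # beartype validates the jaxtyping shape annotations when the function is called.
    return x * scale[:, None]


scale_rows(torch.randn(2, 4), torch.randn(2))    # shapes are consistent: passes
# scale_rows(torch.randn(2, 4), torch.randn(3))  # batch mismatch: raises a beartype violation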
6 changes: 3 additions & 3 deletions tests/unit/test_head_detector.py
@@ -2,7 +2,7 @@
 
 import pytest
 import torch
-from typeguard import TypeCheckError
+from beartype.roar import BeartypeCallHintParamViolation
 
 from transformer_lens import HookedTransformer
 from transformer_lens.head_detector import (
@@ -352,8 +352,8 @@ def test_detect_head_with_cache(error_measure: ErrorMeasure, expected: torch.Ten
 
 
 def test_detect_head_with_invalid_head_name():
-    with pytest.raises((AssertionError, TypeCheckError)) as e:
-        detect_head(model, test_regular_sequence, "test")  # type:ignore
+    with pytest.raises(BeartypeCallHintParamViolation) as e:
+        detect_head(model, test_regular_sequence, "test")
 
 
 def test_detect_head_with_zero_sequence_length():
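For context on the exception swap above: beartype reports a bad argument by raising beartype.roar.BeartypeCallHintParamViolation rather than typeguard's TypeCheckError. A minimal sketch of that behaviour, using a stand-in function with a Literal-typed parameter (the real detect_head signature may differ):

from typing import Literal

import pytest
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation


@beartype
def detect(head_name: Literal["previous_token_head", "duplicate_token_head"]) -> str:
    # Stand-in for detect_head: the parameter only accepts known head names.
    return head_name


def test_rejects_unknown_head_name():
    with pytest.raises(BeartypeCallHintParamViolation):
        detect("test")  # "test" is not one of the allowed literals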
7 changes: 2 additions & 5 deletions tests/unit/test_svd_interpreter.py
@@ -1,6 +1,6 @@
 import pytest
 import torch
-import typeguard
+from beartype.roar import BeartypeCallHintParamViolation
 
 from transformer_lens import HookedTransformer, SVDInterpreter
 
@@ -113,13 +113,10 @@ def test_svd_interpreter_returns_different_answers_for_different_models():
 
 def test_svd_interpreter_fails_on_invalid_vector_type():
     svd_interpreter = SVDInterpreter(model)
-    with pytest.raises(typeguard.TypeCheckError) as e:
+    with pytest.raises(BeartypeCallHintParamViolation) as e:
         svd_interpreter.get_singular_vectors(
             "test", layer_index=0, num_vectors=4, head_index=0
         )
-    assert 'argument "vector_type" (str) did not match any element in the union' in str(
-        e.value
-    )
 
 
 def test_svd_interpreter_fails_on_not_passing_required_head_index():
4 changes: 2 additions & 2 deletions transformer_lens/ActivationCache.py
@@ -215,7 +215,7 @@ def logit_attrs(
         pos_slice: Union[Slice, SliceInput] = None,
         batch_slice: Union[Slice, SliceInput] = None,
         has_batch_dim: bool = True,
-    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims"]:
+    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
         """Returns the logit attributions for the residual stack on an input of tokens, or the logit difference attributions for the residual stack if incorrect_tokens is provided.
 
         Args:
@@ -583,7 +583,7 @@ def apply_ln_to_stack(
         pos_slice: Union[Slice, SliceInput] = None,
         batch_slice: Union[Slice, SliceInput] = None,
         has_batch_dim: bool = True,
-    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"]:
+    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
         """Takes a stack of components of the residual stream (eg outputs of decompose_resid or accumulated_resid), treats them as the input to a specific layer, and applies the layer norm scaling of that layer to them, using the cached scale factors - simulating what that component of the residual stream contributes to that layer's input.
 
         The layernorm scale is global across the entire residual stream for each layer, batch element and position, which is why we need to use the cached scale factors rather than just applying a new LayerNorm.
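The renamed return dimension (*batch_and_pos_dims_out) matters for runtime checking: under jaxtyping, reusing the same variadic symbol for input and output would require the two shapes to match exactly, but pos_slice and batch_slice can change those dimensions. A minimal sketch of the idea, with a hypothetical function rather than the real ActivationCache methods:

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped
@beartype
def take_first_pos(
    x: Float[torch.Tensor, "components *batch_and_pos_dims d_model"],
) -> Float[torch.Tensor, "components *batch_and_pos_dims_out d_model"]:
    # Slicing away the position dimension changes the middle dims, so the return
    # annotation must bind a different variadic symbol than the input.
    return x[..., 0, :]


take_first_pos(torch.randn(3, 2, 5, 8))  # (components=3, batch=2, pos=5, d_model=8) -> (3, 2, 8)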
18 changes: 4 additions & 14 deletions transformer_lens/FactoredMatrix.py
@@ -5,7 +5,6 @@
 
 import torch
 from jaxtyping import Float
-from typeguard import typeguard_ignore
 
 import transformer_lens.utils as utils
 
@@ -41,9 +40,9 @@ def __matmul__(
         other: Union[
             Float[torch.Tensor, "... rdim new_rdim"],
             Float[torch.Tensor, "rdim"],
-            FactoredMatrix,
+            "FactoredMatrix",
         ],
-    ) -> Union[FactoredMatrix, Float[torch.Tensor, "... ldim"]]:
+    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
         if isinstance(other, torch.Tensor):
             if other.ndim < 2:
                 # It's a vector, so we collapse the factorisation and just return a vector
@@ -65,9 +64,9 @@ def __rmatmul__(
         other: Union[
             Float[torch.Tensor, "... new_rdim ldim"],
             Float[torch.Tensor, "ldim"],
-            FactoredMatrix,
+            "FactoredMatrix",
         ],
-    ) -> Union[FactoredMatrix, Float[torch.Tensor, "... rdim"]]:
+    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
         if isinstance(other, torch.Tensor):
             assert (
                 other.size(-1) == self.ldim
@@ -99,13 +98,11 @@ def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
         return self * scalar
 
     @property
-    @typeguard_ignore
     def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
         """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
         return self.A @ self.B
 
     @property
-    @typeguard_ignore
     def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
         """The reverse product. Only makes sense when ldim==rdim"""
         assert (
@@ -114,7 +111,6 @@ def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]:
         return self.B @ self.A
 
     @property
-    @typeguard_ignore
     def T(self) -> FactoredMatrix:
         return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))
 
@@ -141,22 +137,18 @@ def svd(
         return U, S, Vh
 
     @property
-    @typeguard_ignore
     def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]:
         return self.svd()[0]
 
     @property
-    @typeguard_ignore
     def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
         return self.svd()[1]
 
     @property
-    @typeguard_ignore
     def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
         return self.svd()[2]
 
     @property
-    @typeguard_ignore
     def eigenvalues(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
         """Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv, so Av is an eigenvector of AB with eigenvalue k."""
         return torch.linalg.eig(self.BA).eigenvalues
@@ -215,7 +207,6 @@ def get_corner(self, k=3):
         return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k)
 
     @property
-    @typeguard_ignore
     def ndim(self) -> int:
         return len(self.shape)
 
@@ -235,7 +226,6 @@ def unsqueeze(self, k: int) -> FactoredMatrix:
         return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k))
 
     @property
-    @typeguard_ignore
     def pair(
         self,
     ) -> Tuple[
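One pattern worth noting from the diff above: FactoredMatrix is quoted as a string inside its own method annotations. A bare class name is not yet bound while the class body is still being defined, so a PEP 484 string forward reference is used instead, and beartype resolves it when the decorated method is called. A minimal sketch with a hypothetical class (not the real FactoredMatrix):

from beartype import beartype


class Pair:
    def __init__(self, a: float, b: float) -> None:
        self.a, self.b = a, b

    @beartype
    def combine(self, other: "Pair") -> "Pair":
        # "Pair" is written as a string because the class is still being defined
        # when this annotation is read; beartype resolves the reference lazily.
        return Pair(self.a + other.a, self.b + other.b)


print(Pair(1.0, 2.0).combine(Pair(3.0, 4.0)).a)  # 4.0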
(Diffs for the remaining changed files are not shown.)
