diff --git a/pyproject.toml b/pyproject.toml
index 83fb9e25..31baa4f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "1.1.7"
+version = "1.1.9"
 description = "Transformers at zeta scales"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/tests/nn/modules/test_simple_mamba.py b/tests/nn/modules/test_simple_mamba.py
new file mode 100644
index 00000000..c6c90f35
--- /dev/null
+++ b/tests/nn/modules/test_simple_mamba.py
@@ -0,0 +1,89 @@
+# FILEPATH: /Users/defalt/Desktop/Athena/research/zeta/tests/nn/modules/test_simple_mamba.py
+
+import pytest
+import torch
+from torch import nn
+from zeta.nn.modules.simple_mamba import Mamba, ResidualBlock, RMSNorm
+
+def test_mamba_class_init():
+    model = Mamba(10000, 512, 6)
+
+    assert isinstance(model.embedding, nn.Embedding)
+    assert isinstance(model.layers, nn.ModuleList)
+    assert isinstance(model.norm_f, RMSNorm)
+    assert isinstance(model.lm_head, nn.Linear)
+
+def test_mamba_forward():
+    model = Mamba(10000, 512, 6)
+    x = torch.randint(0, 10000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 10000])
+
+def test_residual_block_class_init():
+    block = ResidualBlock(512)
+
+    assert isinstance(block.norm1, RMSNorm)
+    assert isinstance(block.norm2, RMSNorm)
+    assert isinstance(block.fc1, nn.Linear)
+    assert isinstance(block.fc2, nn.Linear)
+
+def test_residual_block_forward():
+    block = ResidualBlock(512)
+    x = torch.randn(1, 50, 512)
+    out = block(x)
+
+    assert out.shape == torch.Size([1, 50, 512])
+
+def test_mamba_different_vocab_size():
+    model = Mamba(20000, 512, 6)
+    x = torch.randint(0, 20000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 20000])
+
+def test_mamba_different_dim():
+    model = Mamba(10000, 1024, 6)
+    x = torch.randint(0, 10000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 10000])
+
+def test_mamba_different_depth():
+    model = Mamba(10000, 512, 12)
+    x = torch.randint(0, 10000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 10000])
+
+def test_residual_block_different_dim():
+    block = ResidualBlock(1024)
+    x = torch.randn(1, 50, 1024)
+    out = block(x)
+
+    assert out.shape == torch.Size([1, 50, 1024])
+
+def test_mamba_with_dropout():
+    model = Mamba(10000, 512, 6, dropout=0.5)
+    x = torch.randint(0, 10000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 10000])
+
+def test_residual_block_with_dropout():
+    block = ResidualBlock(512, dropout=0.5)
+    x = torch.randn(1, 50, 512)
+    out = block(x)
+
+    assert out.shape == torch.Size([1, 50, 512])
+
+def test_mamba_with_custom_layer():
+    class CustomLayer(nn.Module):
+        def forward(self, x):
+            return x * 2
+
+    model = Mamba(10000, 512, 6, layer=CustomLayer())
+    x = torch.randint(0, 10000, (1, 50))
+    out = model(x)
+
+    assert out.shape == torch.Size([1, 50, 10000])
\ No newline at end of file
diff --git a/tests/test_test_example.py b/tests/test_test_example.py
index b707a6d9..fbcfa709 100644
--- a/tests/test_test_example.py
+++ b/tests/test_test_example.py
@@ -1,5 +1,3 @@
-
-
 import time
 import unittest
 import torch
diff --git a/zeta/nn/biases/relative_position_bias.py b/zeta/nn/biases/relative_position_bias.py
index f7befef9..aae02239 100644
--- a/zeta/nn/biases/relative_position_bias.py
+++ b/zeta/nn/biases/relative_position_bias.py
@@ -4,12 +4,9 @@
 import math
 
 import torch
-import torch.nn as nn
+from torch import nn
 
-from zeta.nn.biases.base import BaseBias
-
-
-class RelativePositionBias(BaseBias):
+class RelativePositionBias(nn.Module):
     def __init__(
         self,
         bidirectional: int = True,
diff --git a/zeta/nn/modules/simple_mamba.py b/zeta/nn/modules/simple_mamba.py
index 67a1a959..7f0c60fc 100644
--- a/zeta/nn/modules/simple_mamba.py
+++ b/zeta/nn/modules/simple_mamba.py
@@ -1,52 +1,263 @@
+from __future__ import annotations
+import math
 import torch
-from torch import nn
-from zeta.nn.modules.rms_norm import RMSNorm
-from zeta.nn.modules.residual import Residual
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat, einsum
+from typing import Optional, Union
+
+
+# [HELPERS] ----------------------------------------------------------------------------------------
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        output = (
+            x
+            * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+            * self.weight
+        )
+
+        return output
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self, dim: int = None, vocab_size: int = None, depth: int = None
+    ):
+        """Simple block wrapping a Mamba block with normalization and a residual connection."""
+        super().__init__()
+        # MambaBlock expects (dim, dim_inner, depth); dim_inner defaults to 2 * dim.
+        self.mixer = MambaBlock(dim, dim * 2, depth)
+        self.norm = RMSNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, l, d)
+
+        Official Implementation:
+            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
+
+            NOTE: the official repo chains residual blocks that look like
+                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
+            where the first Add is a no-op. This is purely for performance reasons, as it
+            allows them to fuse the Add -> Norm.
+
+            We instead implement our residual blocks as the more standard, simpler, and
+            numerically equivalent
+                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...
+
+        """
+        output = self.mixer(self.norm(x)) + x
+
+        return output
+
+
 class Mamba(nn.Module):
     def __init__(
-        self,
-        vocab_size: int,
-        dim: int,
-        depth: int,
-        bias: bool = False,
-        *args,
-        **kwargs,
+        self, vocab_size: int = None, dim: int = None, depth: int = None
     ):
+        """Full Mamba model."""
         super().__init__()
+
         self.embedding = nn.Embedding(vocab_size, dim)
-        self.layers = nn.ModuleList(
-            [
-                Residual(self.rmsnorm, nn.Linear(dim, dim, bias=bias))
-                for _ in range(depth)
-            ]
-        )
-        self.rmsnorm = RMSNorm(dim)
-        self.linear = nn.Linear(dim, vocab_size, bias=bias)
-        self.linear.weight = self.embedding.weight
+        self.layers = nn.ModuleList([ResidualBlock(dim) for _ in range(depth)])
+        self.norm_f = RMSNorm(dim)
+
+        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
+        self.lm_head.weight = (
+            self.embedding.weight
+        )  # Tie output projection to embedding weights. See "Weight Tying" paper
+
+    def forward(self, x):
+        """
+        Args:
+            x (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            logits: shape (b, l, vocab_size)
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        Official Implementation:
+            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
+
+        """
         x = self.embedding(x)
 
         for layer in self.layers:
             x = layer(x)
 
-        x = self.rmsnorm(x)
-        logits = self.linear(x)
+        x = self.norm_f(x)
+        logits = self.lm_head(x)
 
         return logits
 
-# class MambaBlock(nn.Module):
-#     def __init__(
-#         self,
-#         dim,
-#         inner_dim,
-#         bias: bool = False,
-#         conv_bias=None,
-#         dim_conv=None,
-#         *args,
-#         **kwargs,
-#     ):
-#         super().__init__()
+
+class MambaBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_inner: Optional[int],
+        depth: int,
+        d_state: int = 16,
+        expand: int = 2,
+        dt_rank: Union[int, str] = "auto",
+        d_conv: int = 4,
+        conv_bias: bool = True,
+        bias: bool = False,
+    ):
+        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
+        super().__init__()
+        dim_inner = dim_inner or dim * expand
+        dt_rank = math.ceil(dim / 16) if dt_rank == "auto" else dt_rank
+        # Stored because forward() and ssm() read self.dim_inner and self.dt_rank.
+        self.dim_inner = dim_inner
+        self.dt_rank = dt_rank
+
+        self.in_proj = nn.Linear(dim, dim_inner * 2, bias=bias)
+
+        self.conv1d = nn.Conv1d(
+            in_channels=dim_inner,
+            out_channels=dim_inner,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=dim_inner,
+            padding=d_conv - 1,
+        )
+
+        # x_proj takes in `x` and outputs the input-specific Δ, B, C
+        self.x_proj = nn.Linear(dim_inner, dt_rank + d_state * 2, bias=False)
+
+        # dt_proj projects Δ from dt_rank to d_in
+        self.dt_proj = nn.Linear(dt_rank, dim_inner, bias=True)
+
+        A = repeat(torch.arange(1, d_state + 1), "n -> d n", d=dim_inner)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(dim_inner))
+        self.out_proj = nn.Linear(dim_inner, dim, bias=bias)
+
+    def forward(self, x):
+        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
+
+        Args:
+            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, l, d)
+
+        Official Implementation:
+            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
+            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+        """
+        (b, l, d) = x.shape
+
+        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
+        x_and_res = rearrange(x_and_res, "b l x -> b x l")
+        (x, res) = x_and_res.split(
+            split_size=[self.dim_inner, self.dim_inner], dim=1
+        )
+
+        x = self.conv1d(x)[:, :, :l]
+        x = F.silu(x)
+
+        y = self.ssm(x)
+
+        y = y * F.silu(res)
+
+        output = self.out_proj(rearrange(y, "b dim l -> b l dim"))
+
+        return output
+
+    def ssm(self, x):
+        """Runs the SSM. See:
+            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+            - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        Args:
+            x: shape (b, d_in, l)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, d_in, l)
+
+        Official Implementation:
+            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+        """
+        (d_in, n) = self.A_log.shape
+
+        # Compute ∆ A B C D, the state space parameters.
+        #     A, D are input independent
+        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4)
+
+        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
+        D = self.D.float()
+
+        x_dbl = rearrange(x, "b d l -> b l d")
+        x_dbl = self.x_proj(x_dbl)  # (b, l, dt_rank + 2*n)
+
+        (delta, B, C) = x_dbl.split(
+            split_size=[self.dt_rank, n, n], dim=-1
+        )  # delta: (b, l, dt_rank). B, C: (b, l, n)
+        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+        y = self.selective_scan(
+            x, delta, A, B, C, D
+        )  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        return y
+
+    def selective_scan(self, u, delta, A, B, C, D):
+        """Does selective scan algorithm. See:
+            - Section 2 State Space Models in the Mamba paper [1]
+            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+            - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        This is the classic discrete state space formula:
+            x(t + 1) = Ax(t) + Bu(t)
+            y(t)     = Cx(t) + Du(t)
+        except B and C (and the step size delta, which is used for discretization) are
+        dependent on the input x(t).
+
+        Args:
+            u: shape (b, d_in, l)    (See Glossary at top for definitions of b, l, d_in, n...)
+            delta: shape (b, l, d_in)
+            A: shape (d_in, n)
+            B: shape (b, l, n)
+            C: shape (b, l, n)
+            D: shape (d_in,)
+
+        Returns:
+            output: shape (b, d_in, l)
+
+        Official Implementation:
+            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
+            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
+
+        """
+        (b, d_in, l) = u.shape
+        n = A.shape[1]
+
+        # Discretize continuous parameters (Δ, A, B) (see Section 2 Equation 4 in the Mamba paper [1])
+        # Note that B is parameterized directly
+        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b d_in l n"))
+        deltaB_u = einsum(
+            delta, B, u, "b l d_in, b l n, b d_in l -> b d_in l n"
+        )
+
+        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
+        x = torch.zeros((b, d_in, n), device=u.device)
+        ys = []
+        for i in range(l):
+            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+            y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
+            ys.append(y)
+        y = torch.stack(ys, dim=2)  # (b, d_in, l)
+
+        if D is not None:
+            y = y + u * rearrange(D, "d_in -> d_in 1")
+
+        return y
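
A minimal usage sketch of the new Mamba module, mirroring the shapes exercised in tests/nn/modules/test_simple_mamba.py. It assumes this package is importable as `zeta` (as in the import path used by the tests); it is a sketch, not part of the diff.

# Build the new Mamba language model and run one forward pass.
import torch
from zeta.nn.modules.simple_mamba import Mamba

model = Mamba(vocab_size=10000, dim=512, depth=6)
tokens = torch.randint(0, 10000, (1, 50))  # (batch, seq_len) of token ids
logits = model(tokens)                     # (batch, seq_len, vocab_size)
print(logits.shape)                        # torch.Size([1, 50, 10000])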
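
The core of the new module is selective_scan, which discretizes (Δ, A, B) and then runs the recurrence x(t) = Āx(t-1) + B̄u(t), y(t) = C(t)x(t) + Du(t). The following self-contained sketch reproduces the same recurrence with plain tensor ops instead of einops, using random stand-in tensors, purely to make the shape bookkeeping explicit; it is illustrative and not the implementation above.

# Standalone sketch of the discretized SSM recurrence used by selective_scan().
import torch

b, d_in, n, l = 2, 4, 16, 8
u = torch.randn(b, d_in, l)      # input sequence, shape (b, d_in, l)
delta = torch.rand(b, l, d_in)   # per-step, per-channel step size
A = -torch.rand(d_in, n)         # continuous-time state matrix (kept negative)
B = torch.randn(b, l, n)
C = torch.randn(b, l, n)
D = torch.ones(d_in)

# Discretize: deltaA[b, d, t, :] = exp(delta[b, t, d] * A[d, :]); deltaB_u = delta * B * u
delta_bdl = delta.permute(0, 2, 1)                                  # (b, d_in, l)
deltaA = torch.exp(delta_bdl[..., None] * A[None, :, None, :])      # (b, d_in, l, n)
deltaB_u = delta_bdl[..., None] * B[:, None, :, :] * u[..., None]   # (b, d_in, l, n)

# Sequential scan: x_t = deltaA_t * x_{t-1} + deltaB_u_t, then y_t = C_t . x_t
x = torch.zeros(b, d_in, n)
ys = []
for t in range(l):
    x = deltaA[:, :, t] * x + deltaB_u[:, :, t]
    ys.append((x * C[:, t, :][:, None, :]).sum(-1))  # (b, d_in)
y = torch.stack(ys, dim=2) + u * D[:, None]          # (b, d_in, l), with skip term D * u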