diff --git a/tests/nn/modules/test_kv_cache.py b/tests/nn/modules/test_kv_cache.py
new file mode 100644
index 00000000..7efeb3f8
--- /dev/null
+++ b/tests/nn/modules/test_kv_cache.py
@@ -0,0 +1,165 @@
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from zeta.nn.modules.kv_cache import (
+    KVCache,
+    find_multiple,
+    precompute_freq_cis,
+    setup_cache,
+)
+
+
+# 1. Basic Tests
+def test_find_multiple():
+    assert find_multiple(10, 3) == 12
+    assert find_multiple(15, 5) == 15
+    assert find_multiple(20, 7) == 21
+
+
+def test_precompute_freq_cis():
+    seq_len = 128
+    n_elem = 64
+    freqs = precompute_freq_cis(seq_len, n_elem)
+    # precompute_freq_cis keeps n_elem // 2 frequency pairs per position
+    assert freqs.shape == torch.Size([seq_len, n_elem // 2, 2])
+
+
+def test_kv_cache_creation():
+    cache = KVCache(32, 128, 8, 64)
+    assert isinstance(cache, KVCache)
+
+
+# 2. Utilize Fixtures
+@pytest.fixture
+def sample_cache():
+    return KVCache(16, 64, 4, 32)
+
+
+def test_kv_cache_update(sample_cache):
+    # Five positions are updated, so k_val/v_val carry one slice per position
+    input_pos = torch.randint(0, 64, (5,))
+    k_val = torch.randn(16, 4, 5, 32)
+    v_val = torch.randn(16, 4, 5, 32)
+    k_out, v_out = sample_cache.update(input_pos, k_val, v_val)
+    assert k_out.shape == torch.Size([16, 4, 64, 32])
+    assert v_out.shape == torch.Size([16, 4, 64, 32])
+
+
+# 3. Parameterized Testing
+@pytest.mark.parametrize(
+    "max_batch_size, max_seq_len, heads, head_dim",
+    [(32, 128, 8, 64), (16, 64, 4, 32)],
+)
+def test_setup_cache(max_batch_size, max_seq_len, heads, head_dim):
+    layers = [
+        Mock(attention=Mock(kv_cache=None)),
+        Mock(attention=Mock(kv_cache=None)),
+    ]
+    block_size = 64
+    rope_base = 1000
+    setup_cache(
+        max_batch_size,
+        max_seq_len,
+        head_dim * heads,
+        heads,
+        layers,
+        block_size,
+        rope_base,
+    )
+    for layer in layers:
+        assert isinstance(layer.attention.kv_cache, KVCache)
+
+
+# 1. Edge Cases
+def test_find_multiple_edge_cases():
+    assert find_multiple(0, 5) == 0
+    # k must be non-zero: n % 0 raises ZeroDivisionError
+    with pytest.raises(ZeroDivisionError):
+        find_multiple(5, 0)
+    with pytest.raises(ZeroDivisionError):
+        find_multiple(0, 0)
+
+
+def test_precompute_freq_cis_edge_cases():
+    seq_len = 128
+    n_elem = 0
+    freqs = precompute_freq_cis(seq_len, n_elem)
+    assert freqs.shape == torch.Size([seq_len, 0, 2])
+
+
+# 2. Additional KVCache Tests
+def test_kv_cache_update_empty_input():
+    cache = KVCache(32, 128, 8, 64)
+    input_pos = torch.tensor([], dtype=torch.int64)
+    # Zero positions are updated, so the sequence axis of k_val/v_val is 0
+    k_val = torch.randn(32, 8, 0, 64)
+    v_val = torch.randn(32, 8, 0, 64)
+    k_out, v_out = cache.update(input_pos, k_val, v_val)
+    assert k_out.shape == torch.Size([32, 8, 128, 64])
+    assert v_out.shape == torch.Size([32, 8, 128, 64])
+
+
+def test_kv_cache_update_out_of_bounds_input():
+    cache = KVCache(32, 128, 8, 64)
+    input_pos = torch.tensor([140, 160, 200], dtype=torch.int64)
+    k_val = torch.randn(32, 8, 3, 64)
+    v_val = torch.randn(32, 8, 3, 64)
+    # Positions beyond max_seq_len (128) cannot be written into the cache
+    with pytest.raises((IndexError, RuntimeError)):
+        cache.update(input_pos, k_val, v_val)
+
+
+# 3. Additional setup_cache Tests
+def test_setup_cache_max_seq_len_greater_than_max():
+    layers = [
+        Mock(attention=Mock(kv_cache=None)),
+        Mock(attention=Mock(kv_cache=None)),
+    ]
+    max_batch_size = 16
+    max_seq_len = 64
+    heads = 4
+    head_dim = 32
+    block_size = 32
+    rope_base = 1000
+    setup_cache(
+        max_batch_size,
+        max_seq_len + 10,
+        head_dim * heads,
+        heads,
+        layers,
+        block_size,
+        rope_base,
+    )
+    # setup_cache rounds the requested sequence length up to a multiple of 8
+    padded_seq_len = find_multiple(max_seq_len + 10, 8)
+    for layer in layers:
+        assert isinstance(layer.attention.kv_cache, KVCache)
+        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
+            [max_batch_size, heads, padded_seq_len, head_dim]
+        )
+        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
+            [max_batch_size, heads, padded_seq_len, head_dim]
+        )
+
+
+def test_setup_cache_max_batch_size_greater_than_max():
+    layers = [
+        Mock(attention=Mock(kv_cache=None)),
+        Mock(attention=Mock(kv_cache=None)),
+    ]
+    max_batch_size = 64
+    max_seq_len = 32
+    heads = 4
+    head_dim = 32
+    block_size = 32
+    rope_base = 1000
+    setup_cache(
+        max_batch_size + 10,
+        max_seq_len,
+        head_dim * heads,
+        heads,
+        layers,
+        block_size,
+        rope_base,
+    )
+    for layer in layers:
+        assert isinstance(layer.attention.kv_cache, KVCache)
+        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
+            [max_batch_size + 10, heads, max_seq_len, head_dim]
+        )
+        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
+            [max_batch_size + 10, heads, max_seq_len, head_dim]
+        )
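+
+
+# 4. Dtype Handling
+# A minimal sketch (not exhaustive): the cache buffers should be allocated
+# in the dtype passed to the constructor, and update() writes values into
+# those buffers, casting on assignment.
+def test_kv_cache_dtype():
+    cache = KVCache(2, 16, 2, 8, dtype=torch.float32)
+    assert cache.k_cache.dtype == torch.float32
+    input_pos = torch.tensor([0, 1], dtype=torch.int64)
+    k_val = torch.randn(2, 2, 2, 8)
+    v_val = torch.randn(2, 2, 2, 8)
+    k_out, _ = cache.update(input_pos, k_val, v_val)
+    assert k_out.dtype == torch.float32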
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index c8d1fee3..e169194b 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -44,6 +44,7 @@
 from zeta.nn.modules.lang_conv_module import ConvolutionLanguageBlock
 from zeta.nn.modules.s4 import s4d_kernel
 from zeta.nn.modules.h3 import H3Layer
+from zeta.nn.modules.mlp_mixer import MLPMixer
 
 
 # from zeta.nn.modules.img_reshape import image_reshape
@@ -105,4 +106,5 @@
     "IterativeCrossSelfAttention",
     "ConvolutionLanguageBlock",
     "H3Layer",
+    "MLPMixer",
 ]
diff --git a/zeta/nn/modules/kv_cache.py b/zeta/nn/modules/kv_cache.py
new file mode 100644
index 00000000..7e6c8fba
--- /dev/null
+++ b/zeta/nn/modules/kv_cache.py
@@ -0,0 +1,157 @@
+import torch
+from torch import nn, Tensor
+
+
+# Helpers
+def find_multiple(n: int, k: int) -> int:
+    """Finds the smallest multiple of k that is greater than or equal to n.
+
+    Args:
+        n (int): The value to round up.
+        k (int): The step to round to. Must be non-zero.
+
+    Returns:
+        int: The smallest multiple of k that is >= n.
+    """
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+def precompute_freq_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+    """Precomputes the frequency values for the rotary positional encodings.
+
+    Args:
+        seq_len (int): Number of positions to precompute frequencies for.
+        n_elem (int): Head dimension; n_elem // 2 frequency pairs are kept.
+        base (int, optional): Base of the rotary frequencies. Defaults to
+            10000.
+
+    Returns:
+        Tensor: A [seq_len, n_elem // 2, 2] bfloat16 tensor holding the
+            real and imaginary parts of the rotary frequencies.
+    """
+    freqs = 1.0 / (
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+    )
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=torch.bfloat16)
+
+
+class KVCache(nn.Module):
+    """
+    KVCache is a module that stores the key and value tensors for each
+    position in the input sequence. This is used in the decoder of the
+    Transformer model to store the key and value tensors for each position
+    in the encoder output sequence.
+
+    The cache is updated by calling the update method, which takes the
+    input positions and the key and value tensors for those positions.
+
+    The cache is a tensor of shape [B, H, S, D], where B is the batch size,
+    H is the number of heads, S is the maximum sequence length, and D is
+    the head dimension.
+
+    Args:
+        max_batch_size: The maximum batch size of the model.
+        max_seq_len: The maximum sequence length of the model.
+        heads: The number of heads in the model.
+        head_dim: The dimension of each head.
+        dtype: The datatype of the cache.
+
+    Attributes:
+        k_cache: The key cache.
+        v_cache: The value cache.
+
+    Methods:
+        update: Updates the cache with the given input positions and key
+            and value tensors.
+
+    Input Shapes:
+        input_pos: [S], the positions being written (S <= max_seq_len)
+        k_val: [B, H, S, D]
+        v_val: [B, H, S, D]
+
+    Output Shapes:
+        k_out: [B, H, max_seq_len, D]
+        v_out: [B, H, max_seq_len, D]
+
+    Examples:
+        >>> import torch
+        >>> from zeta.nn.modules.kv_cache import KVCache
+        >>> cache = KVCache(32, 128, 8, 64)
+        >>> input_pos = torch.arange(5)
+        >>> k_val = torch.randn(32, 8, 5, 64)
+        >>> v_val = torch.randn(32, 8, 5, 64)
+        >>> k_out, v_out = cache.update(input_pos, k_val, v_val)
+        >>> k_out.shape
+        torch.Size([32, 8, 128, 64])
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_len: int,
+        heads: int,
+        head_dim: int,
+        dtype=torch.bfloat16,
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, heads, max_seq_len, head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        Updates the cache with the given input positions and key and value.
+
+        Args:
+            input_pos (Tensor): [S] tensor of positions to write.
+            k_val (Tensor): [B, H, S, D] key tensor for those positions.
+            v_val (Tensor): [B, H, S, D] value tensor for those positions.
+
+        Returns:
+            tuple[Tensor, Tensor]: The full key and value caches, each of
+                shape [B, H, max_seq_len, D].
+        """
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos, :] = k_val
+        v_out[:, :, input_pos, :] = v_val
+
+        return k_out, v_out
+
+
+def setup_cache(
+    max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base
+):
+    """Allocates a KVCache for every layer and precomputes the rotary
+    frequencies and causal mask for the given model configuration.
+
+    Args:
+        max_batch_size (int): Maximum batch size the caches must hold.
+        max_seq_len (int): Maximum sequence length; rounded up to a
+            multiple of 8 before allocation.
+        dim (int): Model dimension (heads * head_dim).
+        heads (int): Number of attention heads.
+        layers (Iterable): Transformer blocks; each must expose an
+            `attention` attribute on which `kv_cache` is set.
+        block_size (int): Number of positions in the rotary frequency table.
+        rope_base (int): Base of the rotary frequencies.
+
+    Returns:
+        tuple[Tensor, Tensor]: A [max_seq_len, max_seq_len] boolean causal
+            mask and the precomputed rotary frequencies.
+    """
+    head_dim = dim // heads
+    max_seq_len = find_multiple(max_seq_len, 8)
+
+    for b in layers:
+        b.attention.kv_cache = KVCache(
+            max_batch_size, max_seq_len, heads, head_dim
+        )
+
+    freq_cis = precompute_freq_cis(block_size, head_dim, rope_base)
+    causal_mask = torch.tril(
+        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
+    )
+
+    return causal_mask, freq_cis
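+
+
+# A minimal usage sketch: builds two stand-in layers exposing an
+# `attention` attribute, runs setup_cache, and inspects the allocated
+# caches. The SimpleNamespace layers are illustrative stand-ins, not
+# real attention modules.
+if __name__ == "__main__":
+    from types import SimpleNamespace
+
+    layers = [
+        SimpleNamespace(attention=SimpleNamespace(kv_cache=None))
+        for _ in range(2)
+    ]
+    mask, freqs = setup_cache(
+        max_batch_size=2,
+        max_seq_len=60,  # rounded up to 64 by find_multiple(60, 8)
+        dim=256,
+        heads=4,
+        layers=layers,
+        block_size=64,
+        rope_base=10000,
+    )
+    print(layers[0].attention.kv_cache.k_cache.shape)  # [2, 4, 64, 64]
+    print(mask.shape, freqs.shape)  # [64, 64] and [64, 32, 2]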
diff --git a/zeta/nn/modules/mlp_mixer.py b/zeta/nn/modules/mlp_mixer.py
new file mode 100644
index 00000000..e48a5e26
--- /dev/null
+++ b/zeta/nn/modules/mlp_mixer.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+
+class MLPBlock(nn.Module):
+    """MLPBlock
+
+    Args:
+        dim (int): Width of both linear layers.
+    """
+
+    def __init__(self, dim: int):
+        super(MLPBlock, self).__init__()
+        self.dense1 = nn.Linear(dim, dim)
+        self.dense2 = nn.Linear(dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of MLPBlock
+
+        Args:
+            x (torch.Tensor): Input of shape [..., dim].
+
+        Returns:
+            torch.Tensor: Output of shape [..., dim].
+        """
+        y = self.dense1(x)
+        y = F.gelu(y)
+        return self.dense2(y)
+
+
+class MixerBlock(nn.Module):
+    """MixerBlock
+
+    Args:
+        mlp_dim (int): Width of the token-mixing MLP; must equal the
+            number of tokens.
+        channels_dim (int): Width of the channel-mixing MLP; must equal
+            the channel (hidden) dimension.
+    """
+
+    def __init__(self, mlp_dim: int, channels_dim: int):
+        super(MixerBlock, self).__init__()
+        self.norm1 = nn.LayerNorm(channels_dim)
+        self.tokens_mlp = MLPBlock(mlp_dim)
+
+        self.norm2 = nn.LayerNorm(channels_dim)
+        self.channel_mlp = MLPBlock(channels_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of MixerBlock
+
+        Args:
+            x (torch.Tensor): Input of shape [batch, tokens, channels].
+
+        Returns:
+            torch.Tensor: Output of shape [batch, tokens, channels].
+        """
+        # Token mixing: transpose so the MLP runs across the token axis
+        y = self.norm1(x)
+        y = rearrange(y, "n t c -> n c t")
+        y = self.tokens_mlp(y)
+        y = rearrange(y, "n c t -> n t c")
+        x = x + y
+        # Channel mixing: the MLP runs across the channel axis
+        y = self.norm2(x)
+        return x + self.channel_mlp(y)
+
+
+class MLPMixer(nn.Module):
+    """MLPMixer
+
+    Args:
+        num_classes (int): Number of output classes.
+        num_blocks (int): Number of MixerBlocks.
+        patch_size (int): Side length of each square patch.
+        hidden_dim (int): Channel width of the stem and mixer blocks; the
+            stem also expects this many input channels.
+        tokens_mlp_dim (int): Width of the token-mixing MLP; must equal
+            (height // patch_size) * (width // patch_size).
+        channels_mlp_dim (int): Width of the channel-mixing MLP; must
+            equal hidden_dim.
+
+    Examples:
+        >>> import torch
+        >>> from zeta.nn.modules import MLPMixer
+        >>> model = MLPMixer(10, 8, 16, 32, 64, 32)
+        >>> x = torch.randn(32, 32, 128, 128)
+        >>> model(x).shape
+        torch.Size([32, 10])
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        num_blocks: int,
+        patch_size: int,
+        hidden_dim: int,
+        tokens_mlp_dim: int,
+        channels_mlp_dim: int,
+    ):
+        super(MLPMixer, self).__init__()
+        # NOTE: the stem consumes hidden_dim input channels, so inputs must
+        # already have hidden_dim channels (not, e.g., 3 RGB channels).
+        self.stem = nn.Conv2d(
+            hidden_dim, hidden_dim, kernel_size=patch_size, stride=patch_size
+        )
+        self.mixer_blocks = nn.ModuleList(
+            [
+                MixerBlock(tokens_mlp_dim, channels_mlp_dim)
+                for _ in range(num_blocks)
+            ]
+        )
+        self.pred_head_layernorm = nn.LayerNorm(hidden_dim)
+        self.head = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of MLPMixer
+
+        Args:
+            x (torch.Tensor): Input of shape
+                [batch, hidden_dim, height, width].
+
+        Returns:
+            torch.Tensor: Class logits of shape [batch, num_classes].
+        """
+        x = self.stem(x)
+        x = rearrange(x, "n c h w -> n (h w) c")
+        for mixer_block in self.mixer_blocks:
+            x = mixer_block(x)
+        x = self.pred_head_layernorm(x)
+        # Global average pool over tokens before the classification head
+        x = x.mean(dim=1)
+        return self.head(x)
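+
+
+# Configuration sanity check (an illustrative helper, not part of the
+# module's API): MLPMixer as written requires inputs with hidden_dim
+# channels, (height // patch_size) * (width // patch_size) ==
+# tokens_mlp_dim, and channels_mlp_dim == hidden_dim.
+def check_mixer_config(
+    height: int,
+    width: int,
+    patch_size: int,
+    tokens_mlp_dim: int,
+    hidden_dim: int,
+    channels_mlp_dim: int,
+) -> None:
+    tokens = (height // patch_size) * (width // patch_size)
+    assert tokens == tokens_mlp_dim, f"tokens_mlp_dim should be {tokens}"
+    assert channels_mlp_dim == hidden_dim, (
+        "channels_mlp_dim must equal hidden_dim"
+    )
+
+
+# e.g. check_mixer_config(256, 256, 16, 256, 512, 512) passes for the
+# demo configuration below.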
+
+
+# Example of creating a model instance, guarded so that importing the
+# module does not run the demo.
+if __name__ == "__main__":
+    mlp_mixer = MLPMixer(
+        num_classes=10,
+        num_blocks=8,
+        patch_size=16,
+        hidden_dim=512,
+        tokens_mlp_dim=256,
+        channels_mlp_dim=512,
+    )
+
+    # The stem expects hidden_dim (512) input channels, and a 256x256 image
+    # with 16x16 patches yields (256 // 16) ** 2 = 256 tokens, matching
+    # tokens_mlp_dim.
+    example_input = torch.randn(1, 512, 256, 256)
+    output = mlp_mixer(example_input)
+    print(output.shape)  # torch.Size([1, 10])
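+
+
+# Sketch: reusing the trunk for feature extraction. This is an illustrative
+# helper built from the module's own attributes, not part of its API; it
+# returns pooled token features instead of class logits.
+def extract_features(model: MLPMixer, x: torch.Tensor) -> torch.Tensor:
+    """Returns per-image features of shape [batch, hidden_dim]."""
+    x = model.stem(x)
+    x = rearrange(x, "n c h w -> n (h w) c")
+    for block in model.mixer_blocks:
+        x = block(x)
+    return model.pred_head_layernorm(x).mean(dim=1)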