[MLPMixer]
Kye committed Dec 1, 2023
1 parent deb2513 commit 1457dcc
Showing 4 changed files with 470 additions and 0 deletions.
165 changes: 165 additions & 0 deletions tests/nn/modules/test_kv_cache.py
@@ -0,0 +1,165 @@
from unittest.mock import Mock
import pytest
import torch

from zeta.nn.modules.kv_cache import (
    KVCache,
    find_multiple,
    precompute_freq_cis,
    setup_cache,
)


# 1. Basic Tests
def test_find_multiple():
    assert find_multiple(10, 3) == 12
    assert find_multiple(15, 5) == 15
    assert find_multiple(20, 7) == 21


def test_precompute_freq_cis():
    seq_len = 128
    n_elem = 64
    freqs = precompute_freq_cis(seq_len, n_elem)
    # The cache stacks (cos, sin) pairs for every second element, so the
    # middle dimension is n_elem // 2, not n_elem.
    assert freqs.shape == torch.Size([seq_len, n_elem // 2, 2])


def test_kv_cache_creation():
    cache = KVCache(32, 128, 8, 64)
    assert isinstance(cache, KVCache)


# 2. Utilize Fixtures
@pytest.fixture
def sample_cache():
    return KVCache(16, 64, 4, 32)


def test_kv_cache_update(sample_cache):
    # update expects k_val/v_val whose sequence dimension matches the number
    # of positions being written, and whose dtype matches the cache buffers.
    input_pos = torch.randint(0, 64, (5,))
    k_val = torch.randn(16, 4, 5, 32, dtype=torch.bfloat16)
    v_val = torch.randn(16, 4, 5, 32, dtype=torch.bfloat16)
    k_out, v_out = sample_cache.update(input_pos, k_val, v_val)
    assert k_out.shape == torch.Size([16, 4, 64, 32])
    assert v_out.shape == torch.Size([16, 4, 64, 32])


# 3. Parameterized Testing
@pytest.mark.parametrize(
    "max_batch_size, max_seq_len, heads, head_dim",
    [(32, 128, 8, 64), (16, 64, 4, 32)],
)
def test_setup_cache(max_batch_size, max_seq_len, heads, head_dim):
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    block_size = 64
    rope_base = 1000
    setup_cache(
        max_batch_size,
        max_seq_len,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    for layer in layers:
        assert isinstance(layer.attention.kv_cache, KVCache)


# 1. Edge Cases
def test_find_multiple_edge_cases():
    assert find_multiple(0, 5) == 0
    # k == 0 makes the modulo undefined, so a ZeroDivisionError is expected.
    with pytest.raises(ZeroDivisionError):
        find_multiple(5, 0)
    with pytest.raises(ZeroDivisionError):
        find_multiple(0, 0)


def test_precompute_freq_cis_edge_cases():
    seq_len = 128
    n_elem = 0
    freqs = precompute_freq_cis(seq_len, n_elem)
    assert freqs.shape == torch.Size([seq_len, 0, 2])


# 2. Additional KVCache Tests
def test_kv_cache_update_empty_input():
    cache = KVCache(32, 128, 8, 64)
    # An empty position tensor writes nothing; the k/v tensors must have a
    # matching (empty) sequence dimension.
    input_pos = torch.tensor([], dtype=torch.int64)
    k_val = torch.randn(32, 8, 0, 64, dtype=torch.bfloat16)
    v_val = torch.randn(32, 8, 0, 64, dtype=torch.bfloat16)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert k_out.shape == torch.Size([32, 8, 128, 64])
    assert v_out.shape == torch.Size([32, 8, 128, 64])


def test_kv_cache_update_out_of_bounds_input():
    cache = KVCache(32, 128, 8, 64)
    # Positions beyond max_seq_len (128) cannot be written; indexing the
    # cache with them raises an IndexError.
    input_pos = torch.tensor([140, 160, 200], dtype=torch.int64)
    k_val = torch.randn(32, 8, 3, 64, dtype=torch.bfloat16)
    v_val = torch.randn(32, 8, 3, 64, dtype=torch.bfloat16)
    with pytest.raises(IndexError):
        cache.update(input_pos, k_val, v_val)


# 3. Additional setup_cache Tests
def test_setup_cache_max_seq_len_greater_than_max():
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    max_batch_size = 16
    max_seq_len = 64
    heads = 4
    head_dim = 32
    block_size = 32
    rope_base = 1000
    setup_cache(
        max_batch_size,
        max_seq_len + 10,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    # setup_cache pads the requested sequence length to the next multiple
    # of 8 before allocating the buffers.
    padded_seq_len = find_multiple(max_seq_len + 10, 8)
    for layer in layers:
        assert isinstance(layer.attention.kv_cache, KVCache)
        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
            [max_batch_size, heads, padded_seq_len, head_dim]
        )
        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
            [max_batch_size, heads, padded_seq_len, head_dim]
        )


def test_setup_cache_max_batch_size_greater_than_max():
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    max_batch_size = 64
    max_seq_len = 32
    heads = 4
    head_dim = 32
    block_size = 32
    rope_base = 1000
    setup_cache(
        max_batch_size + 10,
        max_seq_len,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    for layer in layers:
        assert isinstance(layer.attention.kv_cache, KVCache)
        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
            [max_batch_size + 10, heads, max_seq_len, head_dim]
        )
        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
            [max_batch_size + 10, heads, max_seq_len, head_dim]
        )
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -44,6 +44,7 @@
from zeta.nn.modules.lang_conv_module import ConvolutionLanguageBlock
from zeta.nn.modules.s4 import s4d_kernel
from zeta.nn.modules.h3 import H3Layer
from zeta.nn.modules.mlp_mixer import MLPMixer


# from zeta.nn.modules.img_reshape import image_reshape
@@ -105,4 +106,5 @@
"IterativeCrossSelfAttention",
"ConvolutionLanguageBlock",
"H3Layer",
"MLPMixer",
]
157 changes: 157 additions & 0 deletions zeta/nn/modules/kv_cache.py
@@ -0,0 +1,157 @@
import torch
from torch import nn, Tensor


# Helpers
def find_multiple(n: int, k: int) -> int:
    """Finds the smallest multiple of k that is greater than or equal to n.

    Args:
        n (int): The value to round up.
        k (int): The multiple to round up to; must be non-zero.

    Returns:
        int: The smallest multiple of k that is >= n.
    """
    if n % k == 0:
        return n
    return n + k - (n % k)
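

# A quick illustration (sketch): this helper is used below to pad the
# requested sequence length before allocating cache buffers, e.g.
#
#     find_multiple(70, 8)   # -> 72
#     find_multiple(64, 8)   # -> 64 (already a multiple)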


def precompute_freq_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precomputes the frequency values for the positional encodings.

    Args:
        seq_len (int): The maximum sequence length to precompute.
        n_elem (int): The number of elements being rotated (typically the
            head dimension); a frequency is computed for every second one.
        base (int, optional): The base of the inverse-frequency schedule.
            Defaults to 10000.

    Returns:
        Tensor: A bfloat16 cache of shape [seq_len, n_elem // 2, 2] holding
        the (cos, sin) pairs of the rotary embedding angles.
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
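

# Illustrative note: the trailing dimension holds (cos, sin) pairs, since
# torch.polar builds unit-magnitude complex numbers whose real and imaginary
# parts are cos(theta) and sin(theta). A consumer applying rotary embeddings
# might unpack the cache like so (a sketch, not part of this module's API):
#
#     freqs = precompute_freq_cis(seq_len=128, n_elem=64)
#     cos, sin = freqs[..., 0], freqs[..., 1]  # each [128, 32], bfloat16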


class KVCache(nn.Module):
    """
    KVCache stores the key and value tensors for positions in the input
    sequence. It is used during autoregressive decoding in a Transformer
    model to cache the keys and values of already-processed positions so
    they are not recomputed at every step.

    The cache is updated by calling the update method, which takes the
    input positions and the key and value tensors for those positions.

    Each cache is a tensor of shape [B, H, S, D], where B is the batch
    size, H is the number of heads, S is the maximum sequence length, and
    D is the head dimension.

    Args:
        max_batch_size: The maximum batch size of the model.
        max_seq_len: The maximum sequence length of the model.
        heads: The number of heads in the model.
        head_dim: The dimension of each head.
        dtype: The datatype of the cache.

    Attributes:
        k_cache: The key cache.
        v_cache: The value cache.

    Methods:
        update: Updates the cache with the given input positions and key
            and value tensors.

    Input Shapes:
        input_pos: [S'], the positions being written
        k_val: [B, H, S', D]
        v_val: [B, H, S', D]

    Output Shapes:
        k_out: [B, H, S, D]
        v_out: [B, H, S, D]

    Examples:
        >>> import torch
        >>> from zeta.nn.modules.kv_cache import KVCache
        >>> cache = KVCache(32, 128, 8, 64)
        >>> input_pos = torch.randint(0, 128, (5,))
        >>> k_val = torch.randn(32, 8, 5, 64, dtype=torch.bfloat16)
        >>> v_val = torch.randn(32, 8, 5, 64, dtype=torch.bfloat16)
        >>> k_out, v_out = cache.update(input_pos, k_val, v_val)
        >>> k_out.shape
        torch.Size([32, 8, 128, 64])
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
        heads: int,
        head_dim: int,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        cache_shape = (max_batch_size, heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """
        Updates the cache with the given input positions and key and value.

        Args:
            input_pos (Tensor): 1D tensor of positions to write, shape [S'].
            k_val (Tensor): Keys for those positions, shape [B, H, S', D].
            v_val (Tensor): Values for those positions, shape [B, H, S', D].

        Returns:
            tuple[Tensor, Tensor]: The full key and value caches, each of
            shape [B, H, max_seq_len, D].
        """
        # The number of positions must match the sequence dimension of the
        # incoming keys/values.
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        # Write the new keys/values into the cache buffers in place.
        k_out[:, :, input_pos, :] = k_val
        v_out[:, :, input_pos, :] = v_val

        return k_out, v_out
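

# Decoding-loop sketch (illustrative; `k_step` and `v_step` are hypothetical
# per-step projections from an attention module, shaped [B, H, 1, D] and cast
# to the cache dtype):
#
#     cache = KVCache(1, 128, 8, 64)
#     for pos in range(seq_len):
#         input_pos = torch.tensor([pos], dtype=torch.int64)
#         k_full, v_full = cache.update(input_pos, k_step, v_step)
#
# Note that update mutates the registered buffers in place and returns the
# full [B, H, max_seq_len, D] caches, not just the newly written slice.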


def setup_cache(
    max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base
):
    """Sets up the KV cache for the given model.

    Args:
        max_batch_size (int): The maximum batch size to allocate for.
        max_seq_len (int): The maximum sequence length to allocate for;
            rounded up to the next multiple of 8.
        dim (int): The model dimension (heads * head_dim).
        heads (int): The number of attention heads.
        layers (Iterable): The transformer blocks; each must expose an
            `attention` attribute on which `kv_cache` is set.
        block_size (int): The sequence length used to precompute the
            rotary-embedding cache.
        rope_base (int): The base for the rotary-embedding frequencies.

    Returns:
        tuple[Tensor, Tensor]: The causal mask and the precomputed
        rotary-frequency cache.
    """
    head_dim = dim // heads
    max_seq_len = find_multiple(max_seq_len, 8)

    for b in layers:
        b.attention.kv_cache = KVCache(
            max_batch_size, max_seq_len, heads, head_dim
        )

    freq_cis = precompute_freq_cis(block_size, head_dim, rope_base)
    causal_mask = torch.tril(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
    )

    return causal_mask, freq_cis
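

# Usage sketch (illustrative): `layers` only needs objects exposing an
# `.attention` attribute, as tests/nn/modules/test_kv_cache.py demonstrates
# with unittest.mock.Mock. A hypothetical two-layer model might be wired as:
#
#     mask, freqs = setup_cache(
#         max_batch_size=16,
#         max_seq_len=64,
#         dim=128,
#         heads=4,
#         layers=model.layers,  # hypothetical model object
#         block_size=64,
#         rope_base=10000,
#     )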
146 changes: 146 additions & 0 deletions zeta/nn/modules/mlp_mixer.py
Large diffs are not rendered by default.