Kye committed on Dec 1, 2023 · 1 parent deb2513 · commit 1457dcc
Showing 4 changed files with 470 additions and 0 deletions.
@@ -0,0 +1,165 @@
from unittest.mock import Mock

import pytest
import torch

from zeta.nn.modules.kv_cache import (
    KVCache,
    find_multiple,
    precompute_freq_cis,
    setup_cache,
)


# 1. Basic Tests
def test_find_multiple():
    assert find_multiple(10, 3) == 12
    assert find_multiple(15, 5) == 15
    assert find_multiple(20, 7) == 21


def test_precompute_freq_cis():
    seq_len = 128
    n_elem = 64
    freqs = precompute_freq_cis(seq_len, n_elem)
    # Only every other element is assigned a frequency, so the second
    # dimension is n_elem // 2.
    assert freqs.shape == torch.Size([seq_len, n_elem // 2, 2])


def test_kv_cache_creation():
    cache = KVCache(32, 128, 8, 64)
    assert isinstance(cache, KVCache)


# 2. Utilize Fixtures
@pytest.fixture
def sample_cache():
    return KVCache(16, 64, 4, 32)


def test_kv_cache_update(sample_cache):
    input_pos = torch.randint(0, 64, (5,))
    # k_val / v_val carry one entry per position in input_pos and must match
    # the cache dtype (bfloat16 by default).
    k_val = torch.randn(16, 4, 5, 32, dtype=torch.bfloat16)
    v_val = torch.randn(16, 4, 5, 32, dtype=torch.bfloat16)
    k_out, v_out = sample_cache.update(input_pos, k_val, v_val)
    assert k_out.shape == torch.Size([16, 4, 64, 32])
    assert v_out.shape == torch.Size([16, 4, 64, 32])


# 3. Parameterized Testing
@pytest.mark.parametrize(
    "max_batch_size, max_seq_len, heads, head_dim",
    [(32, 128, 8, 64), (16, 64, 4, 32)],
)
def test_setup_cache(max_batch_size, max_seq_len, heads, head_dim):
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    block_size = 64
    rope_base = 1000
    setup_cache(
        max_batch_size,
        max_seq_len,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    for layer in layers:
        # setup_cache attaches a KVCache to each layer's attention module.
        assert isinstance(layer.attention.kv_cache, KVCache)


# 1. Edge Cases
def test_find_multiple_edge_cases():
    assert find_multiple(0, 5) == 0
    # k == 0 is not a valid step; the modulo inside find_multiple raises.
    with pytest.raises(ZeroDivisionError):
        find_multiple(5, 0)
    with pytest.raises(ZeroDivisionError):
        find_multiple(0, 0)


def test_precompute_freq_cis_edge_cases():
    seq_len = 128
    n_elem = 0
    freqs = precompute_freq_cis(seq_len, n_elem)
    assert freqs.shape == torch.Size([seq_len, 0, 2])


# 2. Additional KVCache Tests
def test_kv_cache_update_empty_input():
    cache = KVCache(32, 128, 8, 64)
    input_pos = torch.tensor([], dtype=torch.int64)
    # Zero positions to write: k_val / v_val have an empty sequence dimension.
    k_val = torch.randn(32, 8, 0, 64, dtype=torch.bfloat16)
    v_val = torch.randn(32, 8, 0, 64, dtype=torch.bfloat16)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert k_out.shape == torch.Size([32, 8, 128, 64])
    assert v_out.shape == torch.Size([32, 8, 128, 64])


def test_kv_cache_update_out_of_bounds_input():
    cache = KVCache(32, 128, 8, 64)
    # Positions beyond max_seq_len (128) cannot be written into the cache.
    input_pos = torch.tensor([140, 160, 200], dtype=torch.int64)
    k_val = torch.randn(32, 8, 3, 64, dtype=torch.bfloat16)
    v_val = torch.randn(32, 8, 3, 64, dtype=torch.bfloat16)
    with pytest.raises((IndexError, RuntimeError)):
        cache.update(input_pos, k_val, v_val)


# 3. Additional setup_cache Tests
def test_setup_cache_max_seq_len_greater_than_max():
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    max_batch_size = 16
    max_seq_len = 64
    heads = 4
    head_dim = 32
    block_size = 32
    rope_base = 1000
    setup_cache(
        max_batch_size,
        max_seq_len + 10,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    # setup_cache rounds the sequence length up to a multiple of 8.
    expected_seq_len = find_multiple(max_seq_len + 10, 8)
    for layer in layers:
        assert isinstance(layer.attention.kv_cache, KVCache)
        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
            [max_batch_size, heads, expected_seq_len, head_dim]
        )
        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
            [max_batch_size, heads, expected_seq_len, head_dim]
        )


def test_setup_cache_max_batch_size_greater_than_max():
    layers = [
        Mock(attention=Mock(kv_cache=None)),
        Mock(attention=Mock(kv_cache=None)),
    ]
    max_batch_size = 64
    max_seq_len = 32
    heads = 4
    head_dim = 32
    block_size = 32
    rope_base = 1000
    setup_cache(
        max_batch_size + 10,
        max_seq_len,
        head_dim * heads,
        heads,
        layers,
        block_size,
        rope_base,
    )
    for layer in layers:
        assert isinstance(layer.attention.kv_cache, KVCache)
        assert layer.attention.kv_cache.k_cache.shape == torch.Size(
            [max_batch_size + 10, heads, max_seq_len, head_dim]
        )
        assert layer.attention.kv_cache.v_cache.shape == torch.Size(
            [max_batch_size + 10, heads, max_seq_len, head_dim]
        )
@@ -0,0 +1,157 @@
import torch
from torch import nn, Tensor


# Helpers
def find_multiple(n: int, k: int) -> int:
    """Finds the smallest multiple of k that is greater than or equal to n.

    Args:
        n (int): The value to round up.
        k (int): The step size; the result is a multiple of k.

    Returns:
        int: The smallest multiple of k that is >= n.
    """
    if n % k == 0:
        return n
    return n + k - (n % k)
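
# Worked example of the rounding formula above: find_multiple(10, 3) takes the
# second branch because 10 % 3 == 1, giving 10 + 3 - 1 == 12.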


def precompute_freq_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precomputes the rotary (RoPE) frequencies for the positional encodings.

    Args:
        seq_len (int): Maximum sequence length to precompute rotations for.
        n_elem (int): Number of embedding elements; only every other element
            gets a frequency, so the output carries n_elem // 2 entries.
        base (int, optional): Base of the geometric frequency progression.
            Defaults to 10000.

    Returns:
        Tensor: A [seq_len, n_elem // 2, 2] bfloat16 tensor holding the real
        and imaginary parts of the complex rotations.
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
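
# Shape walk-through for precompute_freq_cis: with seq_len=128 and n_elem=64,
# torch.arange(0, 64, 2) yields 32 inverse frequencies, torch.outer(t, freqs)
# has shape [128, 32], and stacking the real and imaginary parts gives a
# final cache of shape [128, 32, 2].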


class KVCache(nn.Module):
    """
    KVCache stores the key and value tensors for each position in the input
    sequence. It is used in the decoder of a Transformer model to keep the
    key and value tensors of previously processed positions so they are not
    recomputed during autoregressive decoding.

    The cache is updated by calling the update method with the input
    positions and the key and value tensors for those positions.

    Each cache is a tensor of shape [B, H, S, D], where B is the batch size,
    H is the number of heads, S is the maximum sequence length, and D is
    the head dimension.

    Args:
        max_batch_size: The maximum batch size of the model.
        max_seq_len: The maximum sequence length of the model.
        heads: The number of heads in the model.
        head_dim: The dimension of each head.
        dtype: The datatype of the cache.

    Attributes:
        k_cache: The key cache.
        v_cache: The value cache.

    Methods:
        update: Updates the cache with the given input positions and key
            and value tensors.

    Input Shapes:
        input_pos: [S]
        k_val: [B, H, S, D]
        v_val: [B, H, S, D]
        (S is the number of positions being written.)

    Output Shapes:
        k_out: [B, H, max_seq_len, D]
        v_out: [B, H, max_seq_len, D]

    Examples:
        >>> import torch
        >>> from zeta.nn import KVCache
        >>> cache = KVCache(32, 128, 8, 64)
        >>> input_pos = torch.randint(0, 128, (5,))
        >>> k_val = torch.randn(32, 8, 5, 64, dtype=torch.bfloat16)
        >>> v_val = torch.randn(32, 8, 5, 64, dtype=torch.bfloat16)
        >>> k_out, v_out = cache.update(input_pos, k_val, v_val)
        >>> k_out.shape
        torch.Size([32, 8, 128, 64])
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
        heads: int,
        head_dim: int,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        cache_shape = (max_batch_size, heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """
        Writes k_val and v_val into the cache at the given input positions.

        Args:
            input_pos (Tensor): 1-D tensor of positions to update, shape [S].
            k_val (Tensor): Key tensor of shape [B, H, S, D].
            v_val (Tensor): Value tensor of shape [B, H, S, D].

        Returns:
            tuple[Tensor, Tensor]: The full key and value caches, each of
            shape [B, H, max_seq_len, D].
        """
        # One new key/value entry per position, e.g. input_pos: [5],
        # k_val: [B, H, 5, D].
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos, :] = k_val
        v_out[:, :, input_pos, :] = v_val

        return k_out, v_out


def setup_cache(
    max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base
):
    """Sets up the KV caches, rotary frequencies, and causal mask for a model.

    Args:
        max_batch_size (int): Maximum batch size the caches must hold.
        max_seq_len (int): Maximum sequence length; rounded up to a multiple of 8.
        dim (int): Model embedding dimension (heads * head_dim).
        heads (int): Number of attention heads.
        layers (Iterable): Transformer layers whose attention modules receive
            a kv_cache attribute.
        block_size (int): Sequence length used to precompute the rotary frequencies.
        rope_base (int): Base for the rotary (RoPE) frequency progression.

    Returns:
        tuple[Tensor, Tensor]: The causal mask and the precomputed frequencies.
    """
    head_dim = dim // heads
    max_seq_len = find_multiple(max_seq_len, 8)

    # Attach a fresh KVCache to every layer's attention module.
    for b in layers:
        b.attention.kv_cache = KVCache(
            max_batch_size, max_seq_len, heads, head_dim
        )

    freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)
    causal_mask = torch.tril(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
    )

    return causal_mask, freq_cis
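
A minimal end-to-end sketch of how setup_cache and KVCache.update fit together during decoding. The _Attention and _Layer classes below are hypothetical stand-ins (any object exposing an attention attribute with a kv_cache slot would do); only setup_cache, KVCache, and the tensor shapes come from the code above.

import torch

from zeta.nn.modules.kv_cache import setup_cache


# Hypothetical stand-ins for real transformer layers; setup_cache only needs
# an attention object it can attach a kv_cache to.
class _Attention:
    def __init__(self):
        self.kv_cache = None


class _Layer:
    def __init__(self):
        self.attention = _Attention()


layers = [_Layer(), _Layer()]
causal_mask, freq_cis = setup_cache(
    max_batch_size=2,
    max_seq_len=16,
    dim=64,  # heads * head_dim
    heads=4,
    layers=layers,
    block_size=16,
    rope_base=10000,
)

# Write the key/value for one decoding position into the first layer's cache
# and read back the full [B, H, S, D] tensors.
input_pos = torch.tensor([0])
k_val = torch.randn(2, 4, 1, 16, dtype=torch.bfloat16)
v_val = torch.randn(2, 4, 1, 16, dtype=torch.bfloat16)
k_out, v_out = layers[0].attention.kv_cache.update(input_pos, k_val, v_val)
print(k_out.shape)  # torch.Size([2, 4, 16, 16])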