
Transform model to be able to use Attention Sink #6700

Merged
7 changes: 7 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
)

    parser.add_argument(
        "--use_attention_sink",
        default=None,
        type=str,
        help="Use attention sink to enable fluent multi-round conversations. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
    )

    parser.add_argument(
        "--output_prune_map",
        default=None,
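A quick way to sanity-check the new flag is to push it through the parser. A minimal sketch, assuming the module path above is importable and that the parser's remaining arguments all have usable defaults:

```python
# Sketch: exercise the new --use_attention_sink flag end to end.
from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(["--use_attention_sink", "4,2044,1024"])

sink_size, window_size, eviction_batch_size = (
    int(v) for v in args.use_attention_sink.split(",")
)
assert (sink_size, window_size, eviction_batch_size) == (4, 2044, 1024)
```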
19 changes: 19 additions & 0 deletions examples/models/llama/model.py
@@ -201,6 +201,25 @@ def __init__(self, **kwargs):

            sanitize_checkpoint_from_pre_quantization(checkpoint)

        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
            from .source_transformation.attention_sink import enable_attention_sink

            attention_sink_params = self.args.use_attention_sink.split(",")
            assert len(attention_sink_params) == 3
            sink_size = int(attention_sink_params[0])
            window_size = int(attention_sink_params[1])
            eviction_batch_size = int(attention_sink_params[2])

            assert self.args.max_seq_length == sink_size + window_size

            self.model_ = enable_attention_sink(
                module=self.model_,
                params=model_args,
                sink_size=sink_size,
                window_size=window_size,
                eviction_batch_size=eviction_batch_size,
            )
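The assert above pins down the cache-sizing invariant: the transformed KV cache holds exactly `sink_size + window_size` entries, so the model must be exported with a matching `max_seq_length`. A worked example with the flag's suggested values:

```python
# With --use_attention_sink 4,2044,1024 the export must use
# max_seq_length = 4 + 2044 = 2048; any other value trips the assert.
sink_size, window_size = 4, 2044
assert sink_size + window_size == 2048
```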

        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
        # Because we are using device="meta", tensors do not have memory associated with them
        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
118 changes: 117 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
@@ -7,15 +7,22 @@
# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

import types
from typing import Optional

import torch

from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
from executorch.examples.models.llama.llama_transformer import (
    Attention,
    KVCache,
    ModelArgs,
    Rope,
)
from executorch.examples.models.llama.rope import (
    apply_rotary_emb_to_k,
    hf_apply_rotary_emb_to_k,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class RopeWithAttentionSink(Rope):
Expand Down Expand Up @@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
            )
            self.position_shift -= num_to_evict  # pyre-ignore [8]
        return self.position_shift
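To see what eviction does to the cache layout, here is an illustrative sketch only (the PR's `KVCacheWithAttentionSink` instead shifts tokens in place inside the preallocated k/v tensors and re-applies RoPE to the shifted keys): the first `sink_size` tokens are pinned, and when space runs out the oldest `eviction_batch_size` window tokens are dropped in one batch.

```python
import torch

# Illustrative only: attention-sink eviction on a flat list of positions.
def evict(cache: torch.Tensor, sink_size: int, eviction_batch_size: int) -> torch.Tensor:
    sink = cache[:sink_size]  # pinned sink tokens, never evicted
    window = cache[sink_size + eviction_batch_size :]  # drop the oldest batch
    return torch.cat([sink, window])

positions = torch.arange(8)
print(evict(positions, sink_size=2, eviction_batch_size=3))
# tensor([0, 1, 5, 6, 7])
```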


def attention_sink_forward(
    self,
    x: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    input_pos: Optional[torch.Tensor] = None,
):
    assert self.use_kv_cache
    assert input_pos is not None

    bsz, seqlen, _ = x.shape

    # QKV
    q, k, v = self.wq(x), self.wk(x), self.wv(x)
    # Split into per-head views; relies on view_copy elimination during export
    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    # Make space in the KV cache and get the resulting position shift
    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

    # RoPE relative positional embeddings with shifted position in KV cache
    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
    return self.wo(output)
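`position_shift` is zero until the first eviction and only decreases afterwards (`evict_tokens` subtracts `num_to_evict`), so a token's slot in the cache trails its logical position; that is why SDPA receives `input_pos + position_shift`. The arithmetic:

```python
# After 1024 tokens have been evicted, position_shift is -1024, so a token
# at logical position 3000 lands at cache position 3000 - 1024 = 1976.
input_pos, position_shift = 3000, -1024
assert input_pos + position_shift == 1976
```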


def _replace_rope(
    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        return isinstance(child, Rope)

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        return rope_with_attention_sink

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
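The torchao helper walks the module tree and swaps every submodule for which `filter_fn` returns True. A simplified sketch of that traversal (the real implementation in `torchao.quantization.quant_api` also checks the root module and threads the fully qualified name through the recursion):

```python
import torch

# Simplified sketch of _replace_with_custom_fn_if_matches_filter.
def replace_if_matches(model: torch.nn.Module, replacement_fn, filter_fn, fqn: str = ""):
    for name, child in list(model.named_children()):
        child_fqn = f"{fqn}.{name}" if fqn else name
        if filter_fn(child, child_fqn):
            setattr(model, name, replacement_fn(child))  # swap matching child
        else:
            replace_if_matches(child, replacement_fn, filter_fn, child_fqn)
```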


def _replace_attention(
    module: torch.nn.Module,
    rope_with_attention_sink: RopeWithAttentionSink,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
):
    for _, child_module in module._modules.items():
        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
            _replace_attention(
                module=child_module,  # pyre-ignore [6]
                rope_with_attention_sink=rope_with_attention_sink,
                sink_size=sink_size,
                window_size=window_size,
                eviction_batch_size=eviction_batch_size,
            )

        if isinstance(child_module, Attention):
            kv_cache = child_module.kv_cache
            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
                n_heads=kv_cache.n_heads,
                head_dim=kv_cache.head_dim,
                transpose_cache=kv_cache.transpose_cache,
                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
                rope=rope_with_attention_sink,
                max_batch_size=kv_cache.max_batch_size,
                window_size=window_size,
                sink_size=sink_size,
                eviction_batch_size=eviction_batch_size,
                dtype=kv_cache.k_cache.dtype,
            )
            child_module.kv_cache = kv_cache_with_attention_sink
            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
            child_module.forward = types.MethodType(  # pyre-ignore
                attention_sink_forward, child_module
            )
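Rebinding `forward` with `types.MethodType` attaches the new function to each `Attention` instance, where it shadows the class-level method, so no subclassing is needed. A self-contained sketch of the pattern:

```python
import types

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x

def doubled_forward(self, x):
    return 2 * x

m = Toy()
# Bind the free function to this instance; m.forward now shadows the
# class-level forward, exactly as attention_sink_forward does above.
m.forward = types.MethodType(doubled_forward, m)
assert m(torch.ones(1)).item() == 2.0
```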


def enable_attention_sink(
    module: torch.nn.Module,
    params: ModelArgs,
    sink_size: int,
    window_size: int,
    eviction_batch_size: int,
) -> torch.nn.Module:
    """
    Transform the model to be able to run inference with Attention Sink.
    There are three main steps:
    - Replace Rope with RopeWithAttentionSink
    - Replace Attention's KVCache with KVCacheWithAttentionSink
    - Replace Attention's forward with attention_sink_forward
    """
    rope_with_attention_sink = RopeWithAttentionSink(
        params=params,
        window_size=window_size,
        sink_size=sink_size,
        eviction_batch_size=eviction_batch_size,
    )
    _replace_rope(module, rope_with_attention_sink)
    _replace_attention(
        module=module,
        rope_with_attention_sink=rope_with_attention_sink,
        sink_size=sink_size,
        window_size=window_size,
        eviction_batch_size=eviction_batch_size,
    )
    return module
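A usage sketch mirroring the call site added to model.py above; `model` and `model_args` here stand in for the Llama example's usual model construction:

```python
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)

# model / model_args assumed to come from the example's normal setup.
model = enable_attention_sink(
    module=model,
    params=model_args,
    sink_size=4,
    window_size=2044,
    eviction_batch_size=1024,
)
```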