diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 9a290968a3..ea4296cc52 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
 
+    parser.add_argument(
+        "--use_attention_sink",
+        default=None,
+        type=str,
+        help="Use attention sink to enable fluent multi-round conversation. '<sink_size>,<window_size>,<eviction_batch_size>', e.g., '4,2044,1024'. sink_size + window_size must equal max_seq_length.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 0f83e404a3..2385aba6d5 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -201,6 +201,25 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
+        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+            from .source_transformation.attention_sink import enable_attention_sink
+
+            attention_sink_params = self.args.use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+            sink_size = int(attention_sink_params[0])
+            window_size = int(attention_sink_params[1])
+            eviction_batch_size = int(attention_sink_params[2])
+
+            assert self.args.max_seq_length == sink_size + window_size  # sink + window must fill the KV cache
+
+            self.model_ = enable_attention_sink(
+                module=self.model_,
+                params=model_args,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 8450600d2b..b534a98e07 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,15 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+import types
 from typing import Optional
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -206,3 +213,113 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
         )
         self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # Plain view() here; we rely on view_copy elimination during export
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Make space in the KV cache and get the resulting position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    for _, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention(
+                module=child_module,  # pyre-ignore [6]
+                rope_with_attention_sink=rope_with_attention_sink,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
+        if isinstance(child_module, Attention):
+            kv_cache = child_module.kv_cache
+            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+                n_heads=kv_cache.n_heads,
+                head_dim=kv_cache.head_dim,
+                transpose_cache=kv_cache.transpose_cache,
+                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
+                rope=rope_with_attention_sink,
+                max_batch_size=kv_cache.max_batch_size,
+                window_size=window_size,
+                sink_size=sink_size,
+                eviction_batch_size=eviction_batch_size,
+                dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
+            child_module.forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, child_module
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace Attention's KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module
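
Usage, as a minimal sketch: the flag string decomposes exactly as in the `model.py` hook above. Here `apply_attention_sink`, `model`, and `model_args` are hypothetical stand-ins (an already-constructed llama `Transformer` and its `ModelArgs`); the parsing, the size check, and the `enable_attention_sink` call follow the diff.

```python
# Hedged sketch: mirrors the model.py wiring above. `model` and
# `model_args` are assumed inputs; `apply_attention_sink` is a
# hypothetical helper, not part of this PR.
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)


def apply_attention_sink(model, model_args, flag: str, max_seq_length: int):
    # flag is '<sink_size>,<window_size>,<eviction_batch_size>', e.g. '4,2044,1024'
    sink_size, window_size, eviction_batch_size = (int(v) for v in flag.split(","))
    # model.py asserts that the sink and the sliding window exactly fill the cache
    assert max_seq_length == sink_size + window_size
    return enable_attention_sink(
        module=model,
        params=model_args,
        sink_size=sink_size,
        window_size=window_size,
        eviction_batch_size=eviction_batch_size,
    )
```

On the CLI this corresponds to `--use_attention_sink 4,2044,1024` with a max sequence length of 2048, since 4 + 2044 = 2048.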
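
The module surgery in `_replace_attention` relies on `types.MethodType` to rebind `forward` on each `Attention` instance without rebuilding it or touching its class. A self-contained toy of that pattern, with illustrative names that are not part of the PR:

```python
import types

import torch


class TinyAttention(torch.nn.Module):
    """Illustrative stand-in for the real Attention module."""

    def __init__(self):
        super().__init__()
        self.scale = 2.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Bound via types.MethodType, so `self` is the TinyAttention instance
    # and the replacement can read its attributes, just as
    # attention_sink_forward reads self.kv_cache and self.SDPA in the PR.
    return x * self.scale + 1.0


m = TinyAttention()
m.forward = types.MethodType(patched_forward, m)  # rebinds on this instance only
assert torch.equal(m.forward(torch.ones(2)), torch.tensor([3.0, 3.0]))
```

Because the rebound function receives the instance as `self`, `attention_sink_forward` can read `self.kv_cache` and `self.SDPA`, which is why the diff swaps those attributes to the attention-sink variants before rebinding `forward`.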