From 8a46c774b7134c32a7cf2fbc80b90c8fc72eaea0 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Wed, 6 Nov 2024 16:05:11 -0800
Subject: [PATCH 1/2] Transform model to be able to use Attention Sink

This PR adds necessary functions for transforming the model to be able to use
Attention Sink.

Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)

[ghstack-poisoned]
---
 examples/models/llama/export_llama_lib.py     |   7 ++
 examples/models/llama/model.py                |  14 +++
 .../source_transformation/attention_sink.py   | 114 +++++++++++++++++-
 3 files changed, 134 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 23b3589c2a..e94cc99bf5 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -432,6 +432,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
 
+    parser.add_argument(
+        "--use_attention_sink",
+        default="4,2044,1024",
+        type=str,
+        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<eviction_batch_size>'"
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 0f83e404a3..4a085d6c20 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -200,6 +200,20 @@ def __init__(self, **kwargs):
             )
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
+
+        if hasattr(self.args, "use_attention_sink"):
+            from .source_transformation.sink_attention import (
+                enable_attention_sink,
+            )
+            attention_sink_params = self.args.use_attention_sink.split(",")
+            assert len(attention_sink_params) == 3
+
+            self.model_ = enable_attention_sink(
+                module=self.model_,
+                params=model_args,
+                sink_size=int(attention_sink_params[0]),
+                window_size=int(attention_sink_params[1]),
+                eviction_batch_size=int(attention_sink_params[2]))
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 5326d5c477..68d0f39bfd 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,13 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+import types
+from typing import Optional
+
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -167,3 +176,106 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
             )
             self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Prepare for space in KV cache and get position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    shifted_position = input_pos + position_shift
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, shifted_position)
+
+    output = self.SDPA(shifted_position, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_kv_cache(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, KVCache)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+            n_heads=child.n_heads,
+            head_dim=child.head_dim,
+            transpose_cache=child.transpose_cache,
+            enable_dynamic_shape=child.enable_dynamic_shape,
+            rope=rope_with_attention_sink,
+            max_batch_size=child.max_batch_size,
+            window_size=window_size,
+            sink_size=sink_size,
+            eviction_batch_size=eviction_batch_size,
+            dtype=child.k_cache.dtype,
+        )
+        return kv_cache_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention_forward(module: torch.nn.Module):
+    for name, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention_forward(child_module)  # pyre-ignore [6]
+
+        if isinstance(child_module, Attention):
+            module._modules[name].forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, module._modules[name]
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int = 4,
+    window_size: int = 2044,
+    eviction_batch_size: int = 1,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(params=params)
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_kv_cache(
+        module, rope_with_attention_sink, sink_size, window_size, eviction_batch_size
+    )
+    _replace_attention_forward(module)
+    return module

From 1c0c17ca80b427b59efe33fe74f50d89e4c29f42 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Wed, 6 Nov 2024 18:09:09 -0800
Subject: [PATCH 2/2] Update on "Transform model to be able to use Attention Sink"

This PR adds necessary functions for transforming the model to be able to use
Attention Sink.

Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)

[ghstack-poisoned]
---
 examples/models/llama/export_llama_lib.py |  4 ++--
 examples/models/llama/model.py            | 12 ++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index e94cc99bf5..10d9ee8f06 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -434,9 +434,9 @@ def build_args_parser() -> argparse.ArgumentParser:
 
     parser.add_argument(
         "--use_attention_sink",
-        default="4,2044,1024",
+        default=None,
         type=str,
-        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<eviction_batch_size>'"
+        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<eviction_batch_size>', e.g., '4,2044,1024'.",
     )
 
     parser.add_argument(
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 4a085d6c20..b0016bddc6 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -200,11 +200,10 @@ def __init__(self, **kwargs):
             )
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-
-        if hasattr(self.args, "use_attention_sink"):
-            from .source_transformation.sink_attention import (
-                enable_attention_sink,
-            )
+
+        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+            from .source_transformation.attention_sink import enable_attention_sink
+
             attention_sink_params = self.args.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
 
@@ -213,7 +212,8 @@ def __init__(self, **kwargs):
                 window_size=int(attention_sink_params[1]),
-                eviction_batch_size=int(attention_sink_params[2]))
+                eviction_batch_size=int(attention_sink_params[2]),
+            )
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
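
Usage note (not part of the patches above): a minimal sketch of how the enable_attention_sink
transformation could be exercised in eager mode. The ModelArgs/Transformer construction and the
forward call below are illustrative assumptions from memory of the surrounding llama example code,
not part of this change; only enable_attention_sink and its sink_size/window_size/eviction_batch_size
parameters come from these patches, and the values mirror the '4,2044,1024' example accepted by
--use_attention_sink.

    # Hypothetical eager-mode sketch; field names and forward signature are assumptions.
    import torch

    from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
    from executorch.examples.models.llama.source_transformation.attention_sink import (
        enable_attention_sink,
    )

    # Assumed small config; a real run loads these from a checkpoint via model.py.
    params = ModelArgs(
        dim=256,
        n_layers=2,
        n_heads=4,
        vocab_size=512,
        max_seq_len=2048,
        use_kv_cache=True,
        enable_dynamic_shape=True,
    )
    model = Transformer(params)

    # Mirrors "--use_attention_sink 4,2044,1024": keep 4 sink tokens, slide a
    # 2044-token window, and evict 1024 tokens at a time once the window is full.
    model = enable_attention_sink(
        module=model,
        params=params,
        sink_size=4,
        window_size=2044,
        eviction_batch_size=1024,
    )

    # Token-by-token decoding passes an explicit position, which
    # attention_sink_forward shifts by the number of evicted tokens.
    tokens = torch.tensor([[1]], dtype=torch.long)
    logits = model(tokens, input_pos=torch.tensor([0], dtype=torch.long))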