From 5f27b75aa42eb047fd05dbc339aa94236a989e94 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Tue, 29 Oct 2024 14:46:50 -0700
Subject: [PATCH] Move RoPE-related logic together

Right now, RoPE-related code is scattered across a few different places
in `llama_transformer`, which makes it hard to change anything about
RoPE. This PR moves all RoPE-related logic into its own `Rope` module.

Differential Revision: [D65102268](https://our.internmc.facebook.com/intern/diff/D65102268/)

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py | 142 +++++++++++----------
 1 file changed, 75 insertions(+), 67 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 76e8730328..3656a66df6 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -143,6 +143,69 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.Tensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seq_len is most likely 1. We want to slice from start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, calling .item() produces symints
+                # because it queries data from the tensor.
+                # This path avoids that for the MPS backend, although MPS can
+                # probably support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
@@ -262,7 +325,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +376,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +392,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,13 +480,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -435,10 +494,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)
         h = x + h
 
         if hasattr(self, "block_sparse_moe"):
@@ -456,9 +513,10 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
@@ -466,23 +524,6 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
         self,
@@ -498,42 +539,9 @@ def forward(
         )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
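
For reviewers, a minimal sketch (not part of the patch) of the new call pattern: `Transformer` now builds a single `Rope` and threads it through each `TransformerBlock`/`Attention`, which pass the projected `q`/`k` plus `seqlen` and `input_pos` to it. The import path and the `ModelArgs` field values below are illustrative assumptions, not verified defaults.

```python
# Hypothetical usage sketch; import path and ModelArgs arguments are assumptions.
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope

params = ModelArgs(dim=64, n_heads=4, max_seq_len=32, use_kv_cache=False)
rope = Rope(params)  # precomputes and registers the freqs_cos / freqs_sin buffers

bsz, seqlen = 1, 8
head_dim = params.dim // params.n_heads
q = torch.randn(bsz, seqlen, params.n_heads, head_dim)
k = torch.randn(bsz, seqlen, params.n_heads, head_dim)

# With use_kv_cache=False, input_pos stays None and Rope applies the first
# seqlen rotary frequencies, mirroring what Attention.forward now does internally.
q_rot, k_rot = rope.forward(q, k, seqlen)
```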