Skip to content

Commit

Permalink
[ET-VK][ez] Apply rotary embedding as Module (#6422)
Browse files Browse the repository at this point in the history
Pull Request resolved: #6391

## Context

As title. Wrap the `apply_rotary_emb` function call in a `nn.Module` to make it easy to perform a source module replacement for rotary embedding calculation.

The Vulkan delegate will use the source module replacement technique to insert a custom op to calculate rotary embeddings.
ghstack-source-id: 249175724
@exported-using-ghexport

Differential Revision: [D64697589](https://our.internmc.facebook.com/intern/diff/D64697589/)

Co-authored-by: Stephen Jia <[email protected]>
  • Loading branch information
kirklandsign and SS-JIA authored Oct 21, 2024
1 parent 1247545 commit 0309854
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import torch.nn.functional as F

from executorch.examples.models.llama.rope import (
apply_rotary_emb,
hf_apply_rotary_emb,
hf_precompute_freqs_cis,
precompute_freqs_cis,
RotaryEmbedding,
)

from torch import nn
Expand Down Expand Up @@ -311,7 +311,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
if args.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = apply_rotary_emb
self.apply_rotary_emb = RotaryEmbedding()

def forward(
self,
Expand Down
15 changes: 15 additions & 0 deletions examples/models/llama/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def apply_rotary_emb(
return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
return xq_out, xk_out


# ======================= HuggingFace Implementation ========================


Expand Down

0 comments on commit 0309854

Please sign in to comment.