Skip to content

Commit

Permalink
Update on "[ET-VK] Introduce rotary embedding custom op"
Browse files Browse the repository at this point in the history
## Context

As title; introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details.

Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/)

[ghstack-poisoned]
  • Loading branch information
SS-JIA committed Oct 21, 2024
2 parents fda00c2 + f914d9b commit 6547b38
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
11 changes: 7 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ layout(constant_id = 3) const int packed_dim = 0;
*
* The computation of rotary positional embeddings can be summarized with the
* following equations:
*
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
*
Expand All @@ -55,9 +55,9 @@ layout(constant_id = 3) const int packed_dim = 0;
* The even components of the output multiply the even components of the inputs
* with the freqs_cos tensor, and the odd components of the inputs with the
* freqs_sin tensor. The odd components of the output swap this. Throughout the
* implements the even components have the _r suffix and the odd components have
* the _i suffix; this is likely a reference to complex numbers which can be
* used to represent rotations.
* implementation the even components have the _r suffix and the odd components
* have the _i suffix; this is a reference to complex numbers which can be used
* to represent rotations.
*
* Note that this implementation assumes that all input tensors have the width
* dim as the packed dim.
Expand Down Expand Up @@ -97,6 +97,9 @@ void main() {
write_texel(xqout, x_pos_1, xout_tex_1);
write_texel(xqout, x_pos_2, xout_tex_2);

// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
// may have a larger height dim than xk and xkout. Only compute xkout if this
// invocation is still within bounds.
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
return;
}
Expand Down
12 changes: 10 additions & 2 deletions backends/vulkan/test/op_tests/rotary_embedding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ void test_reference(
graph.copy_from_staging(
staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel());

EXPECT_TRUE(at::allclose(xq_out, vk_xq_out));
EXPECT_TRUE(at::allclose(xk_out, vk_xk_out));
EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4));
EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4));
}

TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) {
Expand All @@ -170,3 +170,11 @@ TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) {
/*n_kv_heads=*/8,
/*dim=*/2048);
}

TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) {
test_reference(
/*n_heads=*/32,
/*n_kv_heads=*/8,
/*dim=*/2048,
/*seq_len=*/3);
}

0 comments on commit 6547b38

Please sign in to comment.