-
Notifications
You must be signed in to change notification settings - Fork 388
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ET-VK] Introduce rotary embedding custom op (#6392)
## Context As title; introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details. Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/) [ghstack-poisoned]
- Loading branch information
Showing
5 changed files
with
439 additions
and
0 deletions.
There are no files selected for viewing
123 changes: 123 additions & 0 deletions
123
backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#version 450 core | ||
|
||
#define PRECISION ${PRECISION} | ||
|
||
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} | ||
|
||
${define_required_extensions(DTYPE)} | ||
|
||
layout(std430) buffer; | ||
|
||
${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)} | ||
${layout_declare_ubo(B, "ivec3", "xqout_limits")} | ||
${layout_declare_ubo(B, "ivec3", "xkout_limits")} | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
layout(constant_id = 3) const int packed_dim = 0; | ||
|
||
#include "indexing_utils.h" | ||
|
||
/* | ||
* This shader computes rotary positional embeddings which are used in the Llama | ||
* model architecture. There are 4 input tensors with the following shapes. | ||
* Note that head_dim = embedding_dim / num_heads | ||
* | ||
* 1. xq (batch_size, sequence_len, num_heads, head_dim) | ||
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim) | ||
* 3. freqs_cos (sequence_len, head_dim / 2) | ||
* 4. freqs_cos (sequence_len, head_dim / 2) | ||
* | ||
* Two output tensors are produced, with the same shapes as xq and xk | ||
* respectively. | ||
* | ||
* The computation of rotary positional embeddings can be summarized with the | ||
* following equations: | ||
* | ||
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i] | ||
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i] | ||
* | ||
* Essentially, taking each row along head_dim of the xq and xk tensors, each | ||
* row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively). | ||
* The even components of the output multiply the even components of the inputs | ||
* with the freqs_cos tensor, and the odd components of the inputs with the | ||
* freqs_sin tensor. The odd components of the output swap this. Throughout the | ||
* implementation the even components have the _r suffix and the odd components | ||
* have the _i suffix; this is a reference to complex numbers which can be used | ||
* to represent rotations. | ||
* | ||
* Note that this implementation assumes that all input tensors have the width | ||
* dim as the packed dim. | ||
*/ | ||
void main() { | ||
// Each thread will write to two output locations to maximize data re-use. | ||
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to | ||
// calculate two output texels. | ||
const ivec3 x_pos_1 = ivec3( | ||
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz); | ||
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz); | ||
|
||
if (any(greaterThanEqual(x_pos_2, xqout_limits))) { | ||
return; | ||
} | ||
|
||
const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0); | ||
|
||
VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos); | ||
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos); | ||
|
||
// Compute xqout | ||
|
||
VEC4_T x_tex_1 = load_texel(xq, x_pos_1); | ||
VEC4_T x_tex_2 = load_texel(xq, x_pos_2); | ||
|
||
// Separate into even and odd elements | ||
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); | ||
VEC4_T x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); | ||
|
||
VEC4_T xout_r = x_r * cos_tex - x_i * sin_tex; | ||
VEC4_T xout_i = x_r * sin_tex + x_i * cos_tex; | ||
|
||
VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); | ||
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); | ||
|
||
write_texel(xqout, x_pos_1, xout_tex_1); | ||
write_texel(xqout, x_pos_2, xout_tex_2); | ||
|
||
// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout | ||
// may have a larger height dim than xk and xkout. Only compute xkout if this | ||
// invocation is still within bounds. | ||
if (any(greaterThanEqual(x_pos_2, xkout_limits))) { | ||
return; | ||
} | ||
|
||
// Compute xkout | ||
|
||
x_tex_1 = load_texel(xk, x_pos_1); | ||
x_tex_2 = load_texel(xk, x_pos_2); | ||
|
||
x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); | ||
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); | ||
|
||
xout_r = x_r * cos_tex - x_i * sin_tex; | ||
xout_i = x_r * sin_tex + x_i * cos_tex; | ||
|
||
xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); | ||
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); | ||
|
||
write_texel(xkout, x_pos_1, xout_tex_1); | ||
write_texel(xkout, x_pos_2, xout_tex_2); | ||
} |
10 changes: 10 additions & 0 deletions
10
backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
rotary_embedding: | ||
parameter_names_with_default_values: | ||
DTYPE: float | ||
STORAGE: texture3d | ||
generate_variant_forall: | ||
DTYPE: | ||
- VALUE: half | ||
- VALUE: float | ||
shader_variants: | ||
- NAME: rotary_embedding |
89 changes: 89 additions & 0 deletions
89
backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> | ||
|
||
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h> | ||
|
||
namespace vkcompute { | ||
|
||
void resize_rotary_embedding_node( | ||
ComputeGraph* graph, | ||
const std::vector<ArgGroup>& args, | ||
const std::vector<ValueRef>& extra_args) { | ||
(void)extra_args; | ||
vTensorPtr out = graph->get_tensor(args[0].refs[0]); | ||
vTensorPtr in = graph->get_tensor(args[1].refs[0]); | ||
|
||
std::vector<int64_t> in_sizes = in->sizes(); | ||
// UNCOMMENT BELOW IF NEEDED | ||
// out->virtual_resize(in_sizes); | ||
} | ||
|
||
void add_rotary_embedding_node( | ||
ComputeGraph& graph, | ||
const ValueRef xq, | ||
const ValueRef xk, | ||
const ValueRef freqs_cos, | ||
const ValueRef freqs_sin, | ||
const ValueRef xq_out, | ||
const ValueRef xk_out) { | ||
VK_CHECK_COND(graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, xk)); | ||
VK_CHECK_COND(graph.size_at<int>(-3, xq) == graph.size_at<int>(-3, xk)); | ||
VK_CHECK_COND( | ||
graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, freqs_cos) * 2); | ||
VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin)); | ||
|
||
VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim); | ||
VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim); | ||
VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim); | ||
VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim); | ||
VK_CHECK_COND(graph.has_standard_axis_map(xq)); | ||
VK_CHECK_COND(graph.has_standard_axis_map(xk)); | ||
VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos)); | ||
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); | ||
|
||
std::string kernel_name = "rotary_embedding"; | ||
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); | ||
|
||
utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out); | ||
global_wg_size[0] /= 2; | ||
const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); | ||
|
||
graph.execute_nodes().emplace_back(new DispatchNode( | ||
graph, | ||
// Shader | ||
VK_KERNEL_FROM_STR(kernel_name), | ||
// Workgroup sizes | ||
global_wg_size, | ||
local_wg_size, | ||
// Inputs and Outputs | ||
{{{xq_out, xk_out}, vkapi::kWrite}, | ||
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, | ||
// Parameter buffers | ||
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)}, | ||
// Specialization Constants | ||
{}, | ||
// Resizing Logic | ||
resize_rotary_embedding_node)); | ||
} | ||
|
||
void apply_rotary_emb(ComputeGraph& graph, const std::vector<ValueRef>& args) { | ||
const ValueListPtr out_tuple = graph.get_value_list(args[4]); | ||
const ValueRef xq_out = out_tuple->at(0); | ||
const ValueRef xk_out = out_tuple->at(1); | ||
|
||
add_rotary_embedding_node( | ||
graph, args[0], args[1], args[2], args[3], xq_out, xk_out); | ||
} | ||
|
||
REGISTER_OPERATORS { | ||
VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb); | ||
} | ||
|
||
} // namespace vkcompute |
Oops, something went wrong.