Skip to content

Commit

Permalink
[ET-VK] Introduce rotary embedding custom op (#6392)
Browse files Browse the repository at this point in the history
## Context

As title; introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details.

Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/)

[ghstack-poisoned]
  • Loading branch information
SS-JIA authored Oct 21, 2024
1 parent f914d9b commit 6abf85a
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 0 deletions.
123 changes: 123 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
${layout_declare_ubo(B, "ivec3", "xkout_limits")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;

#include "indexing_utils.h"

/*
* This shader computes rotary positional embeddings which are used in the Llama
* model architecture. There are 4 input tensors with the following shapes.
* Note that head_dim = embedding_dim / num_heads
*
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
* 3. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_cos (sequence_len, head_dim / 2)
*
* Two output tensors are produced, with the same shapes as xq and xk
* respectively.
*
* The computation of rotary positional embeddings can be summarized with the
* following equations:
*
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
*
* Essentially, taking each row along head_dim of the xq and xk tensors, each
* row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively).
* The even components of the output multiply the even components of the inputs
* with the freqs_cos tensor, and the odd components of the inputs with the
* freqs_sin tensor. The odd components of the output swap this. Throughout the
* implementation the even components have the _r suffix and the odd components
* have the _i suffix; this is a reference to complex numbers which can be used
* to represent rotations.
*
* Note that this implementation assumes that all input tensors have the width
* dim as the packed dim.
*/
void main() {
// Each thread will write to two output locations to maximize data re-use.
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to
// calculate two output texels.
const ivec3 x_pos_1 = ivec3(
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);

if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
return;
}

const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);

VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);

// Compute xqout

VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
VEC4_T x_tex_2 = load_texel(xq, x_pos_2);

// Separate into even and odd elements
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
VEC4_T x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);

VEC4_T xout_r = x_r * cos_tex - x_i * sin_tex;
VEC4_T xout_i = x_r * sin_tex + x_i * cos_tex;

VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xqout, x_pos_1, xout_tex_1);
write_texel(xqout, x_pos_2, xout_tex_2);

// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
// may have a larger height dim than xk and xkout. Only compute xkout if this
// invocation is still within bounds.
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
return;
}

// Compute xkout

x_tex_1 = load_texel(xk, x_pos_1);
x_tex_2 = load_texel(xk, x_pos_2);

x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);

xout_r = x_r * cos_tex - x_i * sin_tex;
xout_i = x_r * sin_tex + x_i * cos_tex;

xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xkout, x_pos_1, xout_tex_1);
write_texel(xkout, x_pos_2, xout_tex_2);
}
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
rotary_embedding:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: rotary_embedding
89 changes: 89 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void resize_rotary_embedding_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr in = graph->get_tensor(args[1].refs[0]);

std::vector<int64_t> in_sizes = in->sizes();
// UNCOMMENT BELOW IF NEEDED
// out->virtual_resize(in_sizes);
}

void add_rotary_embedding_node(
ComputeGraph& graph,
const ValueRef xq,
const ValueRef xk,
const ValueRef freqs_cos,
const ValueRef freqs_sin,
const ValueRef xq_out,
const ValueRef xk_out) {
VK_CHECK_COND(graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, xk));
VK_CHECK_COND(graph.size_at<int>(-3, xq) == graph.size_at<int>(-3, xk));
VK_CHECK_COND(
graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, freqs_cos) * 2);
VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin));

VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim);
VK_CHECK_COND(graph.has_standard_axis_map(xq));
VK_CHECK_COND(graph.has_standard_axis_map(xk));
VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos));
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));

std::string kernel_name = "rotary_embedding";
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));

utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out);
global_wg_size[0] /= 2;
const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
// Shader
VK_KERNEL_FROM_STR(kernel_name),
// Workgroup sizes
global_wg_size,
local_wg_size,
// Inputs and Outputs
{{{xq_out, xk_out}, vkapi::kWrite},
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
// Parameter buffers
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
// Specialization Constants
{},
// Resizing Logic
resize_rotary_embedding_node));
}

void apply_rotary_emb(ComputeGraph& graph, const std::vector<ValueRef>& args) {
const ValueListPtr out_tuple = graph.get_value_list(args[4]);
const ValueRef xq_out = out_tuple->at(0);
const ValueRef xk_out = out_tuple->at(1);

add_rotary_embedding_node(
graph, args[0], args[1], args[2], args[3], xq_out, xk_out);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
}

} // namespace vkcompute
Loading

0 comments on commit 6abf85a

Please sign in to comment.