Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Introduce rotary embedding custom op #6392

Merged
merged 2 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

// PRECISION, DTYPE and STORAGE are template parameters substituted into the
// dollar-brace placeholders of this file before it is compiled. DTYPE and
// STORAGE are given default values in rotary_embedding.yaml.
#define PRECISION ${PRECISION}

// Texel type used for every load, store and arithmetic operation in main().
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

// Resource bindings, declared in the same order in which the host side
// (RotaryEmbedding.cpp) lists them for the dispatch: the two written outputs,
// the four read-only inputs, then the two uniform buffers that carry the
// texture limits of xqout and xkout used for bounds checking in main().
${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
${layout_declare_ubo(B, "ivec3", "xkout_limits")}

// The workgroup dimensions are supplied through specialization constants 0-2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Specialization constant 3, defaulting to 0. main() below never reads it
// directly. NOTE(review): presumably it is declared for the helpers in
// indexing_utils.h - confirm before removing.
layout(constant_id = 3) const int packed_dim = 0;

#include "indexing_utils.h"

/*
* This shader computes rotary positional embeddings which are used in the Llama
* model architecture. There are 4 input tensors with the following shapes.
* Note that head_dim = embedding_dim / num_heads
*
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
* 3. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_sin (sequence_len, head_dim / 2)
*
* Two output tensors are produced, with the same shapes as xq and xk
* respectively.
*
* The computation of rotary positional embeddings can be summarized with the
* following equations:
*
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
*
* Essentially, taking each row along head_dim of the xq and xk tensors, each
* row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively).
* The even components of the output multiply the even components of the inputs
* with the freqs_cos tensor, and the odd components of the inputs with the
* freqs_sin tensor. The odd components of the output swap this. Throughout the
* implementation the even components have the _r suffix and the odd components
* have the _i suffix; this is a reference to complex numbers which can be used
* to represent rotations.
*
* Note that this implementation assumes that all input tensors have the width
* dim as the packed dim.
*/
void main() {
  // Every invocation produces a pair of horizontally adjacent output texels.
  // A single freqs_cos/freqs_sin texel carries four coefficients, which is
  // exactly what is needed to rotate the eight elements (four even/odd pairs)
  // stored in two consecutive x texels.
  const ivec3 pos_lo =
      ivec3(gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
  const ivec3 pos_hi = ivec3(pos_lo.x + 1, pos_lo.yz);

  // pos_hi has the larger x coordinate, so checking it covers both texels.
  if (any(greaterThanEqual(pos_hi, xqout_limits))) {
    return;
  }

  // The frequency tensors are addressed with the invocation's x and z
  // coordinates.
  // NOTE(review): z is used as the frequency row without any wrap-around,
  // which presumably assumes batch_size == 1 - confirm against callers.
  const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
  const VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
  const VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);

  // ---- xqout ----
  {
    const VEC4_T in_lo = load_texel(xq, pos_lo);
    const VEC4_T in_hi = load_texel(xq, pos_hi);

    // De-interleave: even elements form the "real" part, odd elements the
    // "imaginary" part.
    const VEC4_T re = VEC4_T(in_lo.xz, in_hi.xz);
    const VEC4_T im = VEC4_T(in_lo.yw, in_hi.yw);

    // Complex multiplication by (cos + i * sin).
    const VEC4_T rot_re = re * cos_tex - im * sin_tex;
    const VEC4_T rot_im = re * sin_tex + im * cos_tex;

    // Re-interleave the rotated pairs back into the two output texels.
    write_texel(xqout, pos_lo, VEC4_T(rot_re.x, rot_im.x, rot_re.y, rot_im.y));
    write_texel(xqout, pos_hi, VEC4_T(rot_re.z, rot_im.z, rot_re.w, rot_im.w));
  }

  // ---- xkout ----
  // n_heads >= n_kv_heads, so xq/xqout may be taller than xk/xkout. The xk
  // pass only runs when this invocation also lies inside xkout.
  if (all(lessThan(pos_hi, xkout_limits))) {
    const VEC4_T in_lo = load_texel(xk, pos_lo);
    const VEC4_T in_hi = load_texel(xk, pos_hi);

    const VEC4_T re = VEC4_T(in_lo.xz, in_hi.xz);
    const VEC4_T im = VEC4_T(in_lo.yw, in_hi.yw);

    const VEC4_T rot_re = re * cos_tex - im * sin_tex;
    const VEC4_T rot_im = re * sin_tex + im * cos_tex;

    write_texel(xkout, pos_lo, VEC4_T(rot_re.x, rot_im.x, rot_re.y, rot_im.y));
    write_texel(xkout, pos_hi, VEC4_T(rot_re.z, rot_im.z, rot_re.w, rot_im.w));
  }
}
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Shader variant configuration for rotary_embedding.glsl.
rotary_embedding:
  # Default values for the template parameters referenced by the shader.
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  # One variant is generated for each listed DTYPE value; the host code selects
  # between them by appending a dtype suffix to the kernel name.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: rotary_embedding
89 changes: 89 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

/*
 * Resize callback attached to the rotary embedding dispatch node.
 *
 * `args` mirrors the argument groups handed to the DispatchNode in
 * add_rotary_embedding_node():
 *   args[0].refs = {xq_out, xk_out}                  (written)
 *   args[1].refs = {xq, xk, freqs_cos, freqs_sin}    (read)
 *
 * Each output has the same shape as its corresponding input, so both outputs
 * are resized to track the current sizes of xq and xk respectively. Previously
 * this callback fetched the tensors but performed no resize at all (and only
 * ever looked at the first output), leaving the output sizes stale whenever
 * the input sizes changed.
 *
 * `extra_args` is unused.
 */
void resize_rotary_embedding_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr xq_out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr xk_out = graph->get_tensor(args[0].refs[1]);
  vTensorPtr xq = graph->get_tensor(args[1].refs[0]);
  vTensorPtr xk = graph->get_tensor(args[1].refs[1]);

  // Pass the sizes straight through instead of copying them into a local
  // vector first.
  xq_out->virtual_resize(xq->sizes());
  xk_out->virtual_resize(xk->sizes());
}

/*
 * Validates the operands and appends the dispatch node that runs the
 * rotary_embedding shader.
 *
 * xq / xk are the tensors to rotate, freqs_cos / freqs_sin hold the rotation
 * coefficients, and xq_out / xk_out receive the results. The shader assumes
 * width-packed tensors with the standard axis map, which is enforced here for
 * all four inputs.
 */
void add_rotary_embedding_node(
    ComputeGraph& graph,
    const ValueRef xq,
    const ValueRef xk,
    const ValueRef freqs_cos,
    const ValueRef freqs_sin,
    const ValueRef xq_out,
    const ValueRef xk_out) {
  // --- Shape requirements ---------------------------------------------------
  // xq and xk must agree on their last dim and on their third-from-last dim.
  VK_CHECK_COND(graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, xk));
  VK_CHECK_COND(graph.size_at<int>(-3, xq) == graph.size_at<int>(-3, xk));
  // The frequency tensors carry one coefficient per (even, odd) element pair,
  // so their last dim is half that of xq.
  VK_CHECK_COND(
      graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, freqs_cos) * 2);
  VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin));

  // --- Memory layout requirements -------------------------------------------
  VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.has_standard_axis_map(xq));
  VK_CHECK_COND(graph.has_standard_axis_map(xk));
  VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos));
  VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));

  // --- Shader selection -----------------------------------------------------
  std::string shader_name = "rotary_embedding";
  add_dtype_suffix(shader_name, graph.dtype_of(xq_out));

  // --- Dispatch dimensions --------------------------------------------------
  // The dispatch covers xq_out (the shader itself skips the xk pass where
  // xk_out is smaller). Each invocation writes two texels along x, hence only
  // half as many invocations are needed in that direction.
  utils::uvec3 dispatch_extents = graph.logical_limits_of(xq_out);
  dispatch_extents[0] /= 2;
  const utils::uvec3 workgroup_extents =
      graph.create_local_wg_size(dispatch_extents);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      dispatch_extents,
      workgroup_extents,
      // Written tensors first, then read-only tensors; this order matches the
      // binding declarations in the shader.
      {{{xq_out, xk_out}, vkapi::kWrite},
       {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
      // Uniform buffers: texture limits of both outputs.
      {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
      // No specialization constants are passed from here.
      {},
      // Callback invoked when the node needs to be resized.
      resize_rotary_embedding_node));
}

/*
 * Operator entry point for et_vk.apply_rotary_emb.default.
 *
 * Argument layout:
 *   args[0] = xq, args[1] = xk, args[2] = freqs_cos, args[3] = freqs_sin,
 *   args[4] = value list holding the two outputs (xq_out, xk_out).
 */
void apply_rotary_emb(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // Unpack the output pair from the value list.
  const ValueListPtr outputs = graph.get_value_list(args[4]);
  const ValueRef xq_out = outputs->at(0);
  const ValueRef xk_out = outputs->at(1);

  const ValueRef xq = args[0];
  const ValueRef xk = args[1];
  const ValueRef freqs_cos = args[2];
  const ValueRef freqs_sin = args[3];

  add_rotary_embedding_node(
      graph, xq, xk, freqs_cos, freqs_sin, xq_out, xk_out);
}

// Registers apply_rotary_emb as the implementation of the
// et_vk.apply_rotary_emb.default operator.
REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
}

} // namespace vkcompute
Loading
Loading