forked from nod-ai/shark-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "[sharktank] Revert "[llama] Added the fused rotary embedding …
…kernel (nod-ai#719)" (nod-ai#752)" This reverts commit 63ff841.
- Loading branch information
Showing
9 changed files
with
200 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from sharktank.kernels.base import * | ||
|
||
__all__ = [ | ||
"apply_rotary_embedding", | ||
] | ||
|
||
|
||
@CustomOp.register(library=LIBRARY) | ||
class apply_rotary_embedding(CustomOp): | ||
|
||
signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)" | ||
|
||
def select(self, ksel: KernelSelection): | ||
inputs_desc = ksel.arg_tensor(0) | ||
table_desc = ksel.arg_tensor(1) | ||
out_desc = ksel.return_new_tensor( | ||
inputs_desc.t.shape, dtype=inputs_desc.t.dtype | ||
) | ||
specialize_all_known_dims(inputs_desc) | ||
specialize_all_known_dims(table_desc) | ||
specialize_all_known_dims(out_desc) | ||
|
||
def generate(self, ksel: KernelSelection, kb: KernelBuilder): | ||
|
||
input = kb.arg_value(0) | ||
table = kb.arg_value(1) | ||
|
||
input_tensor_type = RankedTensorType(input.type) | ||
table_tensor_type = RankedTensorType(table.type) | ||
|
||
input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type) | ||
table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type) | ||
|
||
assert input_dtype == table_dtype | ||
|
||
# Generate specialization signature and types. | ||
bs = input.type.shape[0] | ||
sl = input.type.shape[1] | ||
sl = "D" if sl < 0 else sl | ||
heads = input.type.shape[2] | ||
dims = input.type.shape[3] | ||
|
||
template_file = "rotary_embedding.mlir" | ||
target_function_name = ( | ||
f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}" | ||
) | ||
|
||
# Template params. | ||
input_tensor_type = input_asm_type | ||
table_tensor_type = table_asm_type | ||
|
||
target_function = inline_template_function( | ||
kb, | ||
template_file, | ||
target_function_name, | ||
input_tensor_type=input_tensor_type, | ||
table_tensor_type=table_tensor_type, | ||
bs=bs, | ||
sl=sl, | ||
heads=heads, | ||
dims=dims, | ||
dtype=str(input_dtype), | ||
) | ||
kb.yield_results(*call_function(target_function, *kb.arg_bindings)) |
63 changes: 63 additions & 0 deletions
63
sharktank/sharktank/kernels/templates/rotary_embedding.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright 2024 Advanced Micro Devices, Inc. | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
!input_tensor_type = {{input_tensor_type}} | ||
!table_tensor_type = {{table_tensor_type}} | ||
|
||
module { | ||
|
||
util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type { | ||
|
||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c2 = arith.constant 2 : index | ||
%c3 = arith.constant 3 : index | ||
|
||
|
||
%d0 = tensor.dim %input, %c0 : !input_tensor_type | ||
%d1 = tensor.dim %input, %c1 : !input_tensor_type | ||
%d2 = tensor.dim %input, %c2 : !input_tensor_type | ||
%d3 = tensor.dim %input, %c3 : !input_tensor_type | ||
|
||
%empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}> | ||
%empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}} | ||
|
||
%result = linalg.generic { | ||
indexing_maps = [ | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
], | ||
iterator_types = ["parallel", "parallel", "parallel", "parallel"]} | ||
ins(%table : !table_tensor_type ) | ||
outs(%empty : !input_tensor_type) { | ||
^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}): | ||
%0 = linalg.index 0 : index | ||
%1 = linalg.index 1 : index | ||
%2 = linalg.index 2 : index | ||
%3 = linalg.index 3 : index | ||
%div = arith.divui %3, %c2 : index | ||
%mod = arith.remui %3, %c2 : index | ||
%a_cosb = math.cos %b0 : {{dtype}} | ||
%a_sinb = math.sin %b0 : {{dtype}} | ||
%real_index = arith.muli %div, %c2 : index | ||
%imag_index = arith.addi %real_index, %c1 : index | ||
%real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type | ||
%imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type | ||
%cmp = arith.cmpi eq, %mod, %c0 : index | ||
%real_t0 = arith.mulf %real, %a_cosb : {{dtype}} | ||
%real_t1 = arith.mulf %imag, %a_sinb : {{dtype}} | ||
%real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}} | ||
%imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}} | ||
%imag_t1 = arith.mulf %real, %a_sinb : {{dtype}} | ||
%imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}} | ||
%val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}} | ||
linalg.yield %val : {{dtype}} | ||
} -> !input_tensor_type | ||
|
||
util.return %result : !input_tensor_type | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.