From dacd042f592c0bffa5112703f2b9cc790265bef7 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Wed, 15 Jan 2025 15:22:08 +0000 Subject: [PATCH] Revert "[sharktank] Revert "[llama] Added the fused rotary embedding kernel (#719)" (#752)" This reverts commit 63ff841c83afb6b2474b8a456d870ea0f755c169. --- .../sharktank/examples/export_paged_llm_v1.py | 8 +- sharktank/sharktank/kernels/__init__.py | 1 + sharktank/sharktank/kernels/rotary.py | 70 +++++++++++++++ .../kernels/templates/rotary_embedding.mlir | 63 +++++++++++++ .../layers/paged_llama_attention_block.py | 2 +- .../sharktank/layers/rotary_embedding.py | 88 +++++++------------ sharktank/sharktank/models/llama/llama.py | 4 +- .../tests/evaluate/perplexity_iree_test.py | 1 + sharktank/tests/kernels/rotary.py | 31 +++++++ 9 files changed, 200 insertions(+), 68 deletions(-) create mode 100644 sharktank/sharktank/kernels/rotary.py create mode 100644 sharktank/sharktank/kernels/templates/rotary_embedding.mlir create mode 100644 sharktank/tests/kernels/rotary.py diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 183730c65..686533ca2 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -219,22 +219,16 @@ def _(model, tokens, seq_lens, seq_block_ids, cs): else: cache_tensors = cs - sl = tokens.shape[1] - input_mask = model.input_mask(seq_lens, sl) - attention_mask = model.attention_mask(input_mask) - if llama_config.tensor_parallelism_size != 1: shard_count = llama_config.tensor_parallelism_size tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) - cache_tensors = repack_cache(cs, cache_shard_dim) logits = model.prefill( tokens, - attention_mask=attention_mask, + attention_mask=None, # We rely on causal attention seq_block_ids=seq_block_ids, cache_state=cache_tensors, ) diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index 445f44852..1b84f0bee 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -10,6 +10,7 @@ from .mmt_block_scaled_offset_q4 import * from .mmt_block_scaled_q8 import * from .mmt_super_block_scaled_offset_q4 import * +from .rotary import * from .batch_matmul_transpose_b import * from .conv_2d_nchw_fchw import * from .pooling_nchw_sum import * diff --git a/sharktank/sharktank/kernels/rotary.py b/sharktank/sharktank/kernels/rotary.py new file mode 100644 index 000000000..196fc32c2 --- /dev/null +++ b/sharktank/sharktank/kernels/rotary.py @@ -0,0 +1,70 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.kernels.base import * + +__all__ = [ + "apply_rotary_embedding", +] + + +@CustomOp.register(library=LIBRARY) +class apply_rotary_embedding(CustomOp): + + signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)" + + def select(self, ksel: KernelSelection): + inputs_desc = ksel.arg_tensor(0) + table_desc = ksel.arg_tensor(1) + out_desc = ksel.return_new_tensor( + inputs_desc.t.shape, dtype=inputs_desc.t.dtype + ) + specialize_all_known_dims(inputs_desc) + specialize_all_known_dims(table_desc) + specialize_all_known_dims(out_desc) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + + input = kb.arg_value(0) + table = kb.arg_value(1) + + input_tensor_type = RankedTensorType(input.type) + table_tensor_type = RankedTensorType(table.type) + + input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type) + table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type) + + assert input_dtype == table_dtype + + # Generate specialization signature and types. + bs = input.type.shape[0] + sl = input.type.shape[1] + sl = "D" if sl < 0 else sl + heads = input.type.shape[2] + dims = input.type.shape[3] + + template_file = "rotary_embedding.mlir" + target_function_name = ( + f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}" + ) + + # Template params. + input_tensor_type = input_asm_type + table_tensor_type = table_asm_type + + target_function = inline_template_function( + kb, + template_file, + target_function_name, + input_tensor_type=input_tensor_type, + table_tensor_type=table_tensor_type, + bs=bs, + sl=sl, + heads=heads, + dims=dims, + dtype=str(input_dtype), + ) + kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir new file mode 100644 index 000000000..adec6805b --- /dev/null +++ b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+!input_tensor_type = {{input_tensor_type}}
+!table_tensor_type = {{table_tensor_type}}
+
+module {
+
+util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+
+  %d0 = tensor.dim %input, %c0 : !input_tensor_type
+  %d1 = tensor.dim %input, %c1 : !input_tensor_type
+  %d2 = tensor.dim %input, %c2 : !input_tensor_type
+  %d3 = tensor.dim %input, %c3 : !input_tensor_type
+
+  %empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}>
+  %empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}}
+
+  %result = linalg.generic {
+      indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+                       ],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%table : !table_tensor_type )
+      outs(%empty : !input_tensor_type) {
+    ^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}):
+      %0 = linalg.index 0 : index
+      %1 = linalg.index 1 : index
+      %2 = linalg.index 2 : index
+      %3 = linalg.index 3 : index
+      %div = arith.divui %3, %c2 : index
+      %mod = arith.remui %3, %c2 : index
+      %a_cosb = math.cos %b0 : {{dtype}}
+      %a_sinb = math.sin %b0 : {{dtype}}
+      %real_index = arith.muli %div, %c2 : index
+      %imag_index = arith.addi %real_index, %c1 : index
+      %real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type
+      %imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type
+      %cmp = arith.cmpi eq, %mod, %c0 : index
+      %real_t0 = arith.mulf %real, %a_cosb : {{dtype}}
+      %real_t1 = arith.mulf %imag, %a_sinb : {{dtype}}
+      %real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}}
+      %imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}}
+      %imag_t1 = arith.mulf %real, %a_sinb : {{dtype}}
+      %imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}}
+      %val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}}
+      linalg.yield %val : {{dtype}}
+  } -> !input_tensor_type
+
+  util.return %result : !input_tensor_type
+}
+
+}
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 6bd33c93f..d74e2a92d 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
             k=keys,  # [bs, ..., sl, dim]
             v=values,  # [bs, ..., sl, dim]
             a=attention_mask,  # [bs, ..., sl, sl]
-            is_causal=False,  # assumes causal masking when true
+            is_causal=attention_mask is None,  # assumes causal masking when true
             scale=None,  # defaults to 1/sqrt(dim)
         )
 
diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index 99ecf5057..623c02ea6 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -11,6 +11,7 @@
 
 from .base import BaseLayer
 from .. import ops
+from ..
import kernels from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor @@ -25,7 +26,6 @@ def __init__( rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, - static_tables: bool = False, use_table: bool = True, tensor_parallelism_size: int = 1, ): @@ -34,26 +34,14 @@ def __init__( self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf - self.static_tables = static_tables self.use_table = use_table self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 self.tensor_parallelism_size = tensor_parallelism_size - if static_tables: - ops.module_register_buffer( - self, "static_rotary_embed_table", self._create_rotary_embed_table() - ) - else: - self.static_rotary_embed_table = None @property def rotary_embed_table(self): - if self.use_table: - if self.static_tables: - return self.static_rotary_embed_table - return self._create_rotary_embed_table() - - return None + return self._create_rotary_embed_table() def forward( self, @@ -61,33 +49,29 @@ def forward( xt: Union[torch.Tensor, SplitPrimitiveTensor], start_index: int, ): - if isinstance(xt, SplitPrimitiveTensor): - rotary_shards = [None] * xt.shard_count - if self.rotary_embed_table is not None: - assert ( - isinstance(self.rotary_embed_table, ReplicatedTensor) - and xt.shard_count == self.rotary_embed_table.shard_count - ) - rotary_shards = [ - unbox_tensor(shard) for shard in self.rotary_embed_table.shards - ] - - xt_shards = [ - self.forward_unsharded( - xt=unbox_tensor(xt_shard), - start_index=start_index, - rotary_embed_table=rotary_shard, - ) - for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) - ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) - return xt - else: + table = self.rotary_embed_table + if not isinstance(xt, SplitPrimitiveTensor): return self.forward_unsharded( xt=xt, start_index=start_index, - rotary_embed_table=self.rotary_embed_table, + rotary_embed_table=table, + ) + + assert ( + isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count + ) + rotary_shards = [unbox_tensor(shard) for shard in table.shards] + + xt_shards = [ + self.forward_unsharded( + xt=unbox_tensor(xt_shard), + start_index=start_index, + rotary_embed_table=rotary_shard, ) + for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) + ] + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt def _create_interleaved_tensor(_, dim): """Creates a tensor which indexes an tensor such that @@ -143,18 +127,17 @@ def forward_unsharded( # Offset the table based on starting position. 
if self.use_table: freqs_cis = rotary_embed_table[start_index : start_index + sl, :] - freqs_cis = freqs_cis[None, 0:sl, None, :] + freqs_cis = freqs_cis[0:sl, :] else: freqs_cis = torch.arange(sl, device=xt.device) + start_index - freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :] + freqs_cis = self._compute_rotary_embed_table(freqs_cis) assert ( - freqs_cis.shape[1] >= sl + freqs_cis.shape[0] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" - xt_ = ops.view_as_complex(xt_) - xt_ = xt_ * freqs_cis - xt_out = ops.view_as_real(xt_) + freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1)) + xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis) if self.use_hf: xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] @@ -181,7 +164,7 @@ def compute_batch_mask( self.trace_tensor("rope.positions_seq", positions_seq) if self.use_table: - freqs_cis = self.rotary_embed_table[positions_seq] + freqs_cis = self.rotary_embed_table[positions_seq.flatten()] else: shape = positions_seq.shape if isinstance(positions_seq, ReplicatedTensor): @@ -192,11 +175,8 @@ def compute_batch_mask( freqs_cis = ReplicatedTensor(ts=ts) else: freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) - freqs_cis = freqs_cis.unflatten(0, shape) - # Unsqueeze a unit dim for attention heads. - broadcast_freqs_cis = freqs_cis.unsqueeze(2) - return broadcast_freqs_cis + return freqs_cis.unsqueeze(1) def apply_batched_mask( self, @@ -232,9 +212,7 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): if self.use_hf: xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])] - xt_ = ops.view_as_complex(xt) - xt_ = xt_ * mask - xt_out = ops.view_as_real(xt_) + xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask) if self.use_hf: xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] @@ -244,14 +222,10 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count freqs = 1.0 / ( - self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0) ) freqs = torch.outer(t, freqs).float() - - cos = torch.cos(freqs) - sin = torch.sin(freqs) - complex = torch.complex(cos, sin) - return complex + return freqs def _create_rotary_embed_table(self): t = torch.arange(self.max_seqlen, device=self.device) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 0a9a6f1c3..6fef6704e 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): super().__init__( theta, context_length=config.hp.context_length, - static_tables=config.static_tables, device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, @@ -92,7 +91,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): max_seqlen=hp.context_length, device=self.device, use_hf=self.use_hf, - static_tables=config.static_tables, tensor_parallelism_size=config.tensor_parallelism_size, ), ) @@ -126,7 +124,7 @@ def prefill( tokens: Union[torch.Tensor, ReplicatedTensor], *, # [1, 1, batch_seq_len, batch_seq_len] - attention_mask: Union[torch.Tensor, ReplicatedTensor], + attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]], # 
[bs, batch_seq_len // block_seq_stride] seq_block_ids: Union[torch.Tensor, ReplicatedTensor], cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index 1e42bde9c..dc655af59 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -34,6 +34,7 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) + @pytest.mark.xfail(reason="Runtime segfault", run=False) def test_llama3_8B_f16_decomposed(self): # Llama 3.1 8B decomposed diff --git a/sharktank/tests/kernels/rotary.py b/sharktank/tests/kernels/rotary.py new file mode 100644 index 000000000..6c3d032a3 --- /dev/null +++ b/sharktank/tests/kernels/rotary.py @@ -0,0 +1,31 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import torch +import unittest + +from sharktank import kernels +from sharktank import ops + + +class rotary_test(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def test_rotary(self): + dtype = torch.float32 + a = torch.rand([1, 128, 1, 64], dtype=dtype) + rot = torch.rand([128, 32], dtype=dtype) + res_b = ops.view_as_real(torch.complex(rot, rot)) + ref_b = torch.complex(torch.cos(rot), torch.sin(rot)) + + result = kernels.apply_rotary_embedding(a, res_b) + ref = ops.view_as_real(ops.view_as_complex(a) * ref_b[None, :, None, :]) + torch.testing.assert_close(result, ref)
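
A minimal PyTorch sketch of what the generated linalg.generic body computes, for reference when reading the template (the helper name rotary_reference is hypothetical and not part of the patch; shapes follow the kernel's indexing maps: activations [bs, sl, heads, dims] and an angle table [bs, sl, dims], with each angle repeated for its even/odd lane pair as _compute_rotary_embed_table now produces):

import torch


def rotary_reference(x: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # x: [bs, sl, heads, dims]; table: [bs, sl, dims] of raw angles (theta).
    # The kernel takes cos/sin internally, so the table carries angles rather
    # than precomputed complex values.
    cos = torch.cos(table)[:, :, None, :]  # broadcast over the heads axis
    sin = torch.sin(table)[:, :, None, :]
    even = x[..., 0::2]  # "real" lanes
    odd = x[..., 1::2]  # "imaginary" lanes
    out = torch.empty_like(x)
    # Even output lanes: real * cos - imag * sin, angle read at the even lane.
    out[..., 0::2] = even * cos[..., 0::2] - odd * sin[..., 0::2]
    # Odd output lanes: imag * cos + real * sin, angle read at the odd lane.
    out[..., 1::2] = odd * cos[..., 1::2] + even * sin[..., 1::2]
    return out

This mirrors the parity select in rotary_embedding.mlir: each output element reads its angle at its own last-dimension index d3, but always pulls the (real, imag) inputs from the enclosing pair (2*(d3//2), 2*(d3//2) + 1).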