Skip to content

Commit

Permalink
[llama] Added the fused rotary embedding kernel (#719)
Browse files Browse the repository at this point in the history
Reworked rotary embedding application to be performed via a custom
kernel. This includes dropping `static_table` for the sake of
maintenance (it was largely unused). It includes a simple numerical test
however under the hood no numerical change should occur.

Existing baseline vs hugging face remained unchanged.
  • Loading branch information
rsuderman authored Dec 20, 2024
1 parent 7862ff8 commit fc9576b
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 69 deletions.
9 changes: 1 addition & 8 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def main():
hp,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
Expand Down Expand Up @@ -219,22 +218,16 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
attention_mask = model.attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

cache_tensors = repack_cache(cs, cache_shard_dim)

logits = model.prefill(
tokens,
attention_mask=attention_mask,
attention_mask=None, # We rely on causal attention
seq_block_ids=seq_block_ids,
cache_state=cache_tensors,
)
Expand Down
1 change: 1 addition & 0 deletions sharktank/sharktank/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .mmt_block_scaled_offset_q4 import *
from .mmt_block_scaled_q8 import *
from .mmt_super_block_scaled_offset_q4 import *
from .rotary import *
from .batch_matmul_transpose_b import *
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
Expand Down
70 changes: 70 additions & 0 deletions sharktank/sharktank/kernels/rotary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from sharktank.kernels.base import *

__all__ = [
"apply_rotary_embedding",
]


@CustomOp.register(library=LIBRARY)
class apply_rotary_embedding(CustomOp):

signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)"

def select(self, ksel: KernelSelection):
inputs_desc = ksel.arg_tensor(0)
table_desc = ksel.arg_tensor(1)
out_desc = ksel.return_new_tensor(
inputs_desc.t.shape, dtype=inputs_desc.t.dtype
)
specialize_all_known_dims(inputs_desc)
specialize_all_known_dims(table_desc)
specialize_all_known_dims(out_desc)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):

input = kb.arg_value(0)
table = kb.arg_value(1)

input_tensor_type = RankedTensorType(input.type)
table_tensor_type = RankedTensorType(table.type)

input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type)
table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type)

assert input_dtype == table_dtype

# Generate specialization signature and types.
bs = input.type.shape[0]
sl = input.type.shape[1]
sl = "D" if sl < 0 else sl
heads = input.type.shape[2]
dims = input.type.shape[3]

template_file = "rotary_embedding.mlir"
target_function_name = (
f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}"
)

# Template params.
input_tensor_type = input_asm_type
table_tensor_type = table_asm_type

target_function = inline_template_function(
kb,
template_file,
target_function_name,
input_tensor_type=input_tensor_type,
table_tensor_type=table_tensor_type,
bs=bs,
sl=sl,
heads=heads,
dims=dims,
dtype=str(input_dtype),
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
63 changes: 63 additions & 0 deletions sharktank/sharktank/kernels/templates/rotary_embedding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

!input_tensor_type = {{input_tensor_type}}
!table_tensor_type = {{table_tensor_type}}

module {

util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type {

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index


%d0 = tensor.dim %input, %c0 : !input_tensor_type
%d1 = tensor.dim %input, %c1 : !input_tensor_type
%d2 = tensor.dim %input, %c2 : !input_tensor_type
%d3 = tensor.dim %input, %c3 : !input_tensor_type

%empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor<?x?x?x?x{{dtype}}>
%empty = tensor.cast %empty_dyn : tensor<?x?x?x?x{{dtype}}> to {{input_tensor_type}}

%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%table : !table_tensor_type )
outs(%empty : !input_tensor_type) {
^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}):
%0 = linalg.index 0 : index
%1 = linalg.index 1 : index
%2 = linalg.index 2 : index
%3 = linalg.index 3 : index
%div = arith.divui %3, %c2 : index
%mod = arith.remui %3, %c2 : index
%a_cosb = math.cos %b0 : {{dtype}}
%a_sinb = math.sin %b0 : {{dtype}}
%real_index = arith.muli %div, %c2 : index
%imag_index = arith.addi %real_index, %c1 : index
%real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type
%imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type
%cmp = arith.cmpi eq, %mod, %c0 : index
%real_t0 = arith.mulf %real, %a_cosb : {{dtype}}
%real_t1 = arith.mulf %imag, %a_sinb : {{dtype}}
%real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}}
%imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}}
%imag_t1 = arith.mulf %real, %a_sinb : {{dtype}}
%imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}}
%val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}}
linalg.yield %val : {{dtype}}
} -> !input_tensor_type

util.return %result : !input_tensor_type
}

}
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
k=keys, # [bs, ..., sl, dim]
v=values, # [bs, ..., sl, dim]
a=attention_mask, # [bs, ..., sl, sl]
is_causal=False, # assumes causal masking when true
is_causal=attention_mask is None, # assumes causal masking when true
scale=None, # defaults to 1/sqrt(dim)
)

Expand Down
88 changes: 31 additions & 57 deletions sharktank/sharktank/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .base import BaseLayer
from .. import ops
from .. import kernels
from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor


Expand All @@ -25,7 +26,6 @@ def __init__(
rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
):
Expand All @@ -34,60 +34,44 @@ def __init__(
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
if static_tables:
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None

@property
def rotary_embed_table(self):
if self.use_table:
if self.static_tables:
return self.static_rotary_embed_table
return self._create_rotary_embed_table()

return None
return self._create_rotary_embed_table()

def forward(
self,
*,
xt: Union[torch.Tensor, SplitPrimitiveTensor],
start_index: int,
):
if isinstance(xt, SplitPrimitiveTensor):
rotary_shards = [None] * xt.shard_count
if self.rotary_embed_table is not None:
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xt.shard_count == self.rotary_embed_table.shard_count
)
rotary_shards = [
unbox_tensor(shard) for shard in self.rotary_embed_table.shards
]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt
else:
table = self.rotary_embed_table
if not isinstance(xt, SplitPrimitiveTensor):
return self.forward_unsharded(
xt=xt,
start_index=start_index,
rotary_embed_table=self.rotary_embed_table,
rotary_embed_table=table,
)

assert (
isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count
)
rotary_shards = [unbox_tensor(shard) for shard in table.shards]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt

def _create_interleaved_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
Expand Down Expand Up @@ -143,18 +127,17 @@ def forward_unsharded(
# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = freqs_cis[None, 0:sl, None, :]
freqs_cis = freqs_cis[0:sl, :]
else:
freqs_cis = torch.arange(sl, device=xt.device) + start_index
freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]
freqs_cis = self._compute_rotary_embed_table(freqs_cis)

assert (
freqs_cis.shape[1] >= sl
freqs_cis.shape[0] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

xt_ = ops.view_as_complex(xt_)
xt_ = xt_ * freqs_cis
xt_out = ops.view_as_real(xt_)
freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
Expand All @@ -181,7 +164,7 @@ def compute_batch_mask(
self.trace_tensor("rope.positions_seq", positions_seq)

if self.use_table:
freqs_cis = self.rotary_embed_table[positions_seq]
freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
else:
shape = positions_seq.shape
if isinstance(positions_seq, ReplicatedTensor):
Expand All @@ -192,11 +175,8 @@ def compute_batch_mask(
freqs_cis = ReplicatedTensor(ts=ts)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)

# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
return broadcast_freqs_cis
return freqs_cis.unsqueeze(1)

def apply_batched_mask(
self,
Expand Down Expand Up @@ -232,9 +212,7 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
if self.use_hf:
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]

xt_ = ops.view_as_complex(xt)
xt_ = xt_ * mask
xt_out = ops.view_as_real(xt_)
xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
Expand All @@ -244,14 +222,10 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
freqs = 1.0 / (
self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
)
freqs = torch.outer(t, freqs).float()

cos = torch.cos(freqs)
sin = torch.sin(freqs)
complex = torch.complex(cos, sin)
return complex
return freqs

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
Expand Down
4 changes: 1 addition & 3 deletions sharktank/sharktank/models/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
super().__init__(
theta,
context_length=config.hp.context_length,
static_tables=config.static_tables,
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
Expand All @@ -92,7 +91,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
max_seqlen=hp.context_length,
device=self.device,
use_hf=self.use_hf,
static_tables=config.static_tables,
tensor_parallelism_size=config.tensor_parallelism_size,
),
)
Expand Down Expand Up @@ -126,7 +124,7 @@ def prefill(
tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [1, 1, batch_seq_len, batch_seq_len]
attention_mask: Union[torch.Tensor, ReplicatedTensor],
attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]],
# [bs, batch_seq_len // block_seq_stride]
seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
Expand Down
Loading

0 comments on commit fc9576b

Please sign in to comment.