Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llm attn op override #805

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions sharktank/sharktank/kernels/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,111 @@

__all__ = [
"flash_attention",
"masked_flash_attention",
]


@CustomOp.register(library=LIBRARY)
class masked_flash_attention(CustomOp):

signature = "masked_flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)"

def select(self, ksel: KernelSelection):
q_desc = ksel.arg_tensor(0) # Shape b, l, d
k_desc = ksel.arg_tensor(1) # Shape b, s, d
v_desc = ksel.arg_tensor(2) # Shape b, s, e
a_desc = ksel.arg_tensor(3) # Shape b, l, s
s_desc = ksel.arg_tensor(4)

q_bs = q_desc.t.shape[:-2]
k_bs = k_desc.t.shape[:-2]
v_bs = v_desc.t.shape[:-2]
a_bs = a_desc.t.shape[:-2]

bs = len(q_bs)

# Note: kernel does collapse dims to get to a single batch/head dim
torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported")

q_l, q_d = q_desc.t.shape[-2:]
k_s, k_d = k_desc.t.shape[-2:]
v_s, v_e = v_desc.t.shape[-2:]

torch._check(
q_desc.t.dtype.is_floating_point
and k_desc.t.dtype.is_floating_point
and v_desc.t.dtype.is_floating_point
and s_desc.t.dtype.is_floating_point,
lambda: f"flash_attention: Expected floating point",
)

for q_b, k_b, v_b, a_b in zip(q_bs, k_bs, v_bs, a_bs):
torch._check(
q_b == k_b and q_b == v_b and q_b == a_b,
lambda: f"expected matching batch dims: {q_b}, {k_b}, {v_b}, {a_b}",
)

torch._check(q_d == k_d, lambda: f"expected matching qk features: {q_d}, {k_d}")

torch._check(k_s == v_s, lambda: f"expected matching kv length: {q_d}, {k_d}")

q_desc.specialize_dims(0, 1, -1)
k_desc.specialize_dims(0, 1, -1)
v_desc.specialize_dims(0, 1, -1)
a_desc.specialize_dims(0, 1)

# Result 0: Shape batch..., m, n
ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims(
0, 1, -1
)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
q = kb.arg_value(0)
k = kb.arg_value(1)
v = kb.arg_value(2)
a = kb.arg_value(3)
scale = kb.arg_value(4)

q_tensor_type = RankedTensorType(q.type)
scale_tensor_type = RankedTensorType(scale.type)
v_tensor_type = RankedTensorType(v.type)

b1, b2, l, d = q_tensor_type.shape
_, _, s, e = v_tensor_type.shape

# Unspecialized dims will be negative
l = l if l >= 0 else "?"
s = s if s >= 0 else "?"
b = str(int(b1) * int(b2))
i_type_str = str(q_tensor_type.element_type)
scale_type_str = str(scale_tensor_type.element_type)
o_type_str = "f16"

target_function_name = f"sharktank_masked_flash_attention_{b1}_{b2}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}"
kwargs = {
"b": b,
"b1": b1,
"b2": b2,
"l": l,
"d": d,
"s": s,
"e": e,
"i_dtype": i_type_str,
"scale_dtype": scale_type_str,
"o_dtype": o_type_str,
"func_name": target_function_name,
}
template_file = "masked_flash_attention.mlir"
target_function = inline_template_function(
kb,
template_file,
target_function_name,
**kwargs,
)
kb.yield_results(*call_function(target_function, q, k, v, scale, a))
pass


@CustomOp.register(library=LIBRARY)
class flash_attention(CustomOp):

Expand Down
63 changes: 63 additions & 0 deletions sharktank/sharktank/kernels/templates/masked_flash_attention.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

!q_type = tensor<{{b1}}x{{b2}}x{{l}}x{{d}}x{{i_dtype}}>
!k_type = tensor<{{b1}}x{{b2}}x{{s}}x{{d}}x{{i_dtype}}>
!v_type = tensor<{{b1}}x{{b2}}x{{s}}x{{e}}x{{i_dtype}}>
!a_type = tensor<{{b1}}x{{b2}}x{{l}}x{{s}}x{{i_dtype}}>
!trans_v_type = tensor<{{b1}}x{{b2}}x{{e}}x{{s}}x{{i_dtype}}>
!o_type = tensor<{{b1}}x{{b2}}x{{l}}x{{e}}x{{o_dtype}}>
!o_dyn_type = tensor<?x?x?x{{o_dtype}}>
!o_collapsed_type = tensor<{{b}}x{{l}}x{{e}}x{{o_dtype}}>
!q_collapsed_type = tensor<{{b}}x{{l}}x{{d}}x{{i_dtype}}>
!k_collapsed_type = tensor<{{b}}x{{s}}x{{d}}x{{i_dtype}}>
!v_collapsed_type = tensor<{{b}}x{{s}}x{{e}}x{{i_dtype}}>
!a_collapsed_type = tensor<{{b}}x{{l}}x{{s}}x{{i_dtype}}>
!s_type = tensor<{{scale_dtype}}>

module {

util.func private @{{func_name}}(
%q: !q_type,
%k: !k_type,
%v: !v_type,
%s: !s_type,
%a: !a_type) -> !o_type {

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%b0 = arith.constant {{b}} : index


%l = tensor.dim %q, %c2 : !q_type
%e = tensor.dim %v, %c3 : !v_type

%scale = tensor.extract %s[] : !s_type
%empty_dyn = tensor.empty(%b0, %l, %e) : !o_dyn_type
%empty = tensor.cast %empty_dyn : !o_dyn_type to !o_collapsed_type

%collapsed_q = tensor.collapse_shape %q [[0, 1], [2], [3]] : !q_type into !q_collapsed_type
%collapsed_k = tensor.collapse_shape %k [[0, 1], [2], [3]] : !k_type into !k_collapsed_type
%collapsed_v = tensor.collapse_shape %v [[0, 1], [2], [3]] : !v_type into !v_collapsed_type
%collapsed_a = tensor.collapse_shape %a [[0, 1], [2], [3]] : !a_type into !a_collapsed_type

%atten = iree_linalg_ext.attention {indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]}
ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %collapsed_a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> !o_collapsed_type
%expanded_o = tensor.expand_shape %atten [[0,1], [2], [3]] output_shape [{{b1}}, {{b2}}, %l, {{e}}] : !o_collapsed_type into !o_type
util.return %expanded_o : !o_type
}
}
1 change: 0 additions & 1 deletion sharktank/sharktank/ops/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def _match_targets(self, type_spec: tuple):
targets = []
for override in self._overrides:
override_type_spec = override.type_spec

# Check if the override is a boolean type expression and if it is that it
# satisfied.
if self._is_type_expr_target(override_type_spec, type_spec):
Expand Down
42 changes: 40 additions & 2 deletions sharktank/sharktank/ops/attention_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,35 @@


def _extract_linear_scale(t):
# Returns the qs (quantized tensor value) and scale where appropriate
if (
isinstance(t, PlanarQuantizedTensor)
and isinstance(t.layout, TensorScaledLayout)
and t.layout.m is None
and t.layout.m is None # offset must be none or can't fuse scales
):
return t.layout.qs, t.layout.d
return unbox_tensor(t), None


def flash_attention(q, k, v, a, is_causal, scale):
def register_attention_override_by_name(name: str):
"""Provides a way to override available attention kernels
based on something other than a global flag"""
if name == "flash_attention":
scaled_dot_product_attention.override(
PlanarQuantizedTensor,
PlanarQuantizedTensor,
PlanarQuantizedTensor,
NoneType,
)(flash_attention)
elif name == "masked_flash_attention":
scaled_dot_product_attention.override(
AnyTensor, AnyTensor, AnyTensor, AnyTensor
)(masked_flash_attention)
else:
assert False, f"{name} not a registerable override"


def prepare_args(q, k, v, scale):
scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32)

q, qscale = _extract_linear_scale(q)
Expand All @@ -66,8 +85,23 @@ def flash_attention(q, k, v, a, is_causal, scale):
if v.dtype == torch.float32:
v = v.to(torch.float16)

return q, k, v, scale, vscale


def flash_attention(q, k, v, a, is_causal, scale):
assert not is_causal or is_causal == None, "NYI: is_causal iree custom attention"
q, k, v, scale, vscale = prepare_args(q, k, v, scale)
atten = kernels.flash_attention(q, k, v, scale)
atten = atten * vscale if vscale is not None else atten
return atten


def masked_flash_attention(q, k, v, a, is_causal, scale):
assert not is_causal or is_causal == None, "NYI: is_causal iree custom attention"
q, k, v, scale, vscale = prepare_args(q, k, v, scale)
if a.dtype == torch.float32:
a = a.to(torch.float16)
atten = kernels.masked_flash_attention(q, k, v, a, scale)
atten = atten * vscale if vscale is not None else atten
return atten

Expand All @@ -76,3 +110,7 @@ def flash_attention(q, k, v, a, is_causal, scale):
scaled_dot_product_attention.override(
PlanarQuantizedTensor, PlanarQuantizedTensor, PlanarQuantizedTensor, NoneType
)(flash_attention)
if debugging.flags.use_custom_generic_attention:
scaled_dot_product_attention.override(AnyTensor, AnyTensor, AnyTensor, AnyTensor)(
masked_flash_attention
)
3 changes: 3 additions & 0 deletions sharktank/sharktank/utils/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class DebugFlags:
# certain eager use cases are still having problems with these custom
# kernels, so keeping it to unblock progress.
use_custom_iree_kernels: bool = True
use_custom_generic_attention: bool = False

def set(self, part: str):
m = re.match(SETTING_PART_PATTERN, part)
Expand All @@ -56,6 +57,8 @@ def set(self, part: str):
self.save_goldens_path = Path(value)
elif name == "use_custom_iree_kernels":
self.use_custom_iree_kernels = logical_sense
elif name == "use_custom_generic_attention":
self.use_custom_generic_attention = logical_sense
else:
logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)

Expand Down
Loading
Loading