
Commit

run formatters
aciddelgado committed Dec 21, 2023
1 parent f67316d commit 8084585
Showing 5 changed files with 256 additions and 120 deletions.
@@ -377,7 +377,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
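Illustrative note (not part of the commit): the round_multiple lambda above pads head_size up to a multiple of 32 and seqlen_q up to a multiple of 128 before the kernel launch; the same round-up idiom in Python:

def round_multiple(x: int, m: int) -> int:
    # Round x up to the nearest multiple of m, as in the C++ lambda.
    return (x + m - 1) // m * m

assert round_multiple(80, 32) == 96    # head_size 80 -> head_size_rounded 96
assert round_multiple(1, 128) == 128   # seqlen_q 1 -> seqlen_q_rounded 128
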
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -102,8 +102,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1,
bool is_rotary_interleaved=false,
bool is_packed_qkv=false);
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

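Illustrative note (not part of the commit): the shape comments on softmax_lse_accum and out_accum above describe the split-KV accumulation buffers; a rough element-count sketch derived from those comments (the helper name here is hypothetical, not ORT API):

def splitkv_accum_elems(num_splits, batch_size, seqlen_q, num_heads, head_size_rounded):
    # Element counts taken directly from the shape comments in flash_api.h.
    lse_accum = num_splits * batch_size * seqlen_q * num_heads
    out_accum = lse_accum * head_size_rounded
    return lse_accum, out_accum

# e.g. 4 splits, batch 1, seqlen_q 1, 32 heads, head_size_rounded 128:
lse_elems, out_elems = splitkv_accum_elems(4, 1, 1, 32, 128)   # 128 and 16384 elements
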
36 changes: 18 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -51,8 +51,8 @@ Status CheckInputs(const Tensor* query,

if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}

int kv_hidden_size = 0;
@@ -61,35 +61,35 @@ Status CheckInputs(const Tensor* query,
head_size = static_cast<int>(q_hidden_size) / num_heads;
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");

Check warning (GitHub Actions / cpplint) on line 69 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
key_dims.size());
} else if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
} else if (query_dims[1] != key_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
}
kv_hidden_size = static_cast<int>(key_dims[2]);
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
value_dims.size());
} else if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch size)");
"Input 'query' and 'value' shall have same dim 0 (batch size)");
} else if (query_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
} else if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");

Check warning (GitHub Actions / cpplint) on line 94 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
@@ -98,12 +98,12 @@ Status CheckInputs(const Tensor* query,
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");

Check warning (GitHub Actions / cpplint) on line 106 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
q_hidden_size = head_size * num_heads;
kv_hidden_size = head_size * kv_num_heads;
@@ -211,19 +211,19 @@ Status CheckInputs(const Tensor* query,

if (cos_dims[0] != present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 must be of present_sequence_length.");
"cos_cache dimension 0 must be of present_sequence_length.");
}
if (sin_dims[0] != present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 must be of present_sequence_length.");
"sin_cache dimension 0 must be of present_sequence_length.");
}
if (cos_dims[1] != (head_size / 16) * 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (sin_dims[1] != (head_size / 16) * 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
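Illustrative note (not part of the commit): the checks in this file encode the GroupQueryAttention input-shape rules; a condensed Python sketch of the same rules (a hypothetical helper mirroring the error conditions, not the ORT code):

def gqa_head_size(q_hidden_size: int, num_heads: int, kv_num_heads: int, packed_qkv: bool) -> int:
    # num_heads must be an integer multiple of kv_num_heads.
    if num_heads % kv_num_heads != 0:
        raise ValueError("num_heads must be a multiple of kv_num_heads")
    # Packed QKV carries Q, K and V heads in one tensor, so the hidden size
    # divides by num_heads + 2 * kv_num_heads instead of num_heads.
    denom = num_heads + 2 * kv_num_heads if packed_qkv else num_heads
    head_size = q_hidden_size // denom
    if head_size % 8 != 0:
        raise ValueError("head_size must be a multiple of 8")
    return head_size

# Rotary caches, when provided, must satisfy (per the checks above):
#   cos_cache.shape[0] == sin_cache.shape[0] == present_sequence_length
#   cos_cache.shape[1] == sin_cache.shape[1] == (head_size // 16) * 8
# i.e. at most head_size / 2 and a multiple of 8.
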
69 changes: 22 additions & 47 deletions onnxruntime/test/python/transformers/rotary_flash.py
@@ -4,14 +4,13 @@
from typing import Optional, Tuple, Union

import torch

from einops import rearrange, repeat
import triton
import triton.language as tl

from einops import rearrange, repeat

##### TRITON KERNEL FOR ROTARY #####


# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
@@ -81,15 +80,13 @@ def rotary_kernel(
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)

Check warning (Code scanning / lintrunner, RUFF/N806): Variable X in function should be lowercase. See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])

Check warning (Code scanning / lintrunner, RUFF/N806): Variable COS in function should be lowercase. See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])

Check warning (Code scanning / lintrunner, RUFF/N806): Variable SIN in function should be lowercase. See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(
tl.float32
)
sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
x1 = tl.load(
X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
@@ -130,12 +127,8 @@ def rotary_kernel(
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
tl.float32
)
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
@@ -184,12 +177,8 @@ def apply_rotary(
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
@@ -203,11 +192,7 @@ def apply_rotary(
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])

BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))

Check warning (Code scanning / lintrunner, RUFF/N806): Variable BLOCK_K in function should be lowercase. See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

Check warning (Code scanning / lintrunner, RUFF/N806): Variable BLOCK_M in function should be lowercase. See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function

@@ -243,8 +228,10 @@ def apply_rotary(
)
return output


##### ROTARY API #####


def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
@@ -356,9 +343,7 @@ def apply_rotary_emb(
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
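Illustrative note (not part of the commit): a minimal usage sketch of apply_rotary_emb with made-up shapes; the cos/sin layout of (seqlen, rotary_dim // 2) and the rotary base of 10000 are assumptions consistent with the assertions above, and a CUDA device is required for the Triton kernel:

import torch

batch, seqlen, nheads, headdim = 2, 128, 8, 64
rotary_dim = 64  # rotary_dim must be <= headdim
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

# Rotary frequency table, shape (seqlen, rotary_dim // 2), same dtype as x.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda", dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(seqlen, device="cuda", dtype=torch.float32), inv_freq)
cos, sin = freqs.cos().to(x.dtype), freqs.sin().to(x.dtype)

out = apply_rotary_emb(x, cos, sin, interleaved=False)  # rotates the first rotary_dim of headdim
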


# For backward compatibility
@@ -385,9 +370,7 @@ def forward(
# dimensions, we get the same tensor
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
apply_rotary(
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
@@ -429,9 +412,7 @@ def backward(ctx, dqkv):
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
apply_rotary(
dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
)
apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True)
apply_rotary(
dk,
cos_k,
@@ -476,9 +457,7 @@ def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, tor
batch, seqlen, two, nheads, headdim = kv.shape
assert two == 2
k = kv[:, :, 0]
apply_rotary(
k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
@@ -597,10 +576,7 @@ def __init__(
self._sin_k_cached = None

def _compute_inv_freq(self, device=None):
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
@@ -638,8 +614,7 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
- seqlen // 2
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
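Illustrative note (not part of the commit): the reformatted _compute_inv_freq and _update_cos_sin_cache lines above build the standard rotary cos/sin cache; a standalone sketch of the unscaled path (the base, dim and seqlen values are made up):

import torch

def rotary_cos_sin_cache(seqlen: int, dim: int, base: float = 10000.0, dtype=torch.float32):
    # inv_freq as in _compute_inv_freq: 1 / base**(2i / dim) for i = 0 .. dim//2 - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seqlen, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                 # (seqlen, dim // 2)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

cos, sin = rotary_cos_sin_cache(seqlen=16, dim=64)   # each of shape (16, 32)
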