-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,13 @@ | |
from typing import Optional, Tuple, Union | ||
|
||
import torch | ||
|
||
from einops import rearrange, repeat | ||
import triton | ||
import triton.language as tl | ||
|
||
from einops import rearrange, repeat | ||
|
||
##### TRITON KERNEL FOR ROTARY ##### | ||
|
||
|
||
# @triton.autotune( | ||
# configs=[ | ||
# triton.Config({"BLOCK_M": 2}), | ||
|
@@ -81,15 +80,13 @@ def rotary_kernel( | |
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) | ||
Check warning Code scanning / lintrunner RUFF/N806 Warning test
Variable X in function should be lowercase.
See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function |
||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) | ||
Check warning Code scanning / lintrunner RUFF/N806 Warning test
Variable COS in function should be lowercase.
See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function |
||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) | ||
Check warning Code scanning / lintrunner RUFF/N806 Warning test
Variable SIN in function should be lowercase.
See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function |
||
cos = tl.load( | ||
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 | ||
).to(tl.float32) | ||
sin = tl.load( | ||
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 | ||
).to(tl.float32) | ||
x0 = tl.load( | ||
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 | ||
).to(tl.float32) | ||
cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( | ||
tl.float32 | ||
) | ||
sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( | ||
tl.float32 | ||
) | ||
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) | ||
x1 = tl.load( | ||
X + rotary_dim_half * stride_x_headdim, | ||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), | ||
|
@@ -130,12 +127,8 @@ def rotary_kernel( | |
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), | ||
other=0.0, | ||
).to(tl.float32) | ||
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( | ||
tl.float32 | ||
) | ||
x1 = tl.load( | ||
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 | ||
).to(tl.float32) | ||
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) | ||
x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) | ||
if CONJUGATE: | ||
sin = -sin | ||
x0_cos = x0 * cos | ||
|
@@ -184,12 +177,8 @@ def apply_rotary( | |
assert headdim <= 256, "Only support headdim <= 256" | ||
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" | ||
|
||
assert ( | ||
cos.dtype == sin.dtype | ||
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" | ||
assert ( | ||
x.dtype == cos.dtype | ||
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" | ||
assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" | ||
assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" | ||
|
||
cos, sin = cos.contiguous(), sin.contiguous() | ||
if isinstance(seqlen_offsets, torch.Tensor): | ||
|
@@ -203,11 +192,7 @@ def apply_rotary( | |
if rotary_dim < headdim and not inplace: | ||
output[..., rotary_dim:].copy_(x[..., rotary_dim:]) | ||
|
||
BLOCK_K = ( | ||
32 | ||
if rotary_dim <= 32 | ||
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) | ||
) | ||
BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) | ||
Check warning Code scanning / lintrunner RUFF/N806 Warning test
Variable BLOCK_K in function should be lowercase.
See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function |
||
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa | ||
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) | ||
Check warning Code scanning / lintrunner RUFF/N806 Warning test
Variable BLOCK_M in function should be lowercase.
See https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function |
||
|
||
|
@@ -243,8 +228,10 @@ def apply_rotary( | |
) | ||
return output | ||
|
||
|
||
##### ROTARY API ##### | ||
|
||
|
||
def rotate_half(x, interleaved=False): | ||
if not interleaved: | ||
x1, x2 = x.chunk(2, dim=-1) | ||
|
@@ -356,9 +343,7 @@ def apply_rotary_emb( | |
rotary_dim must be <= headdim | ||
Apply rotary embedding to the first rotary_dim of x. | ||
""" | ||
return ApplyRotaryEmb.apply( | ||
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen | ||
) | ||
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen) | ||
|
||
|
||
# For backward compatibility | ||
|
@@ -385,9 +370,7 @@ def forward( | |
# dimensions, we get the same tensor | ||
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") | ||
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) | ||
apply_rotary( | ||
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True | ||
) | ||
apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) | ||
else: | ||
cos_k = cos if cos_k is None else cos_k | ||
sin_k = sin if sin_k is None else sin_k | ||
|
@@ -429,9 +412,7 @@ def backward(ctx, dqkv): | |
cos_k = cos if cos_k is None else cos_k | ||
sin_k = sin if sin_k is None else sin_k | ||
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] | ||
apply_rotary( | ||
dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True | ||
) | ||
apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) | ||
apply_rotary( | ||
dk, | ||
cos_k, | ||
|
@@ -476,9 +457,7 @@ def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, tor | |
batch, seqlen, two, nheads, headdim = kv.shape | ||
assert two == 2 | ||
k = kv[:, :, 0] | ||
apply_rotary( | ||
k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True | ||
) | ||
apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) | ||
if isinstance(seqlen_offsets, int): | ||
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward | ||
ctx.seqlen_offsets = seqlen_offsets | ||
|
@@ -597,10 +576,7 @@ def __init__( | |
self._sin_k_cached = None | ||
|
||
def _compute_inv_freq(self, device=None): | ||
return 1.0 / ( | ||
self.base | ||
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) | ||
) | ||
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) | ||
|
||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): | ||
# Reset the tables if the sequence length has changed, | ||
|
@@ -638,8 +614,7 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): | |
self._sin_cached = torch.sin(freqs).to(dtype) | ||
else: | ||
power = ( | ||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) | ||
- seqlen // 2 | ||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 | ||
) / self.scale_base | ||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") | ||
# We want the multiplication by scale to happen in fp32 | ||
|