Commit ebb212c
Summary: Pull Request resolved: #2486

Differential Revision: D61055780
1 parent: a8ce4b5

Showing 3 changed files with 427 additions and 0 deletions.
@@ -0,0 +1 @@
from .operator import Operator
@@ -0,0 +1,306 @@
import torch
import triton
import triton.language as tl
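

# Fused SwiGLU-style FFN forward:
#     p = silu(x @ w1.T) * (x @ w3.T)
#     out = p @ w2
# where w1 and w3 are stored stacked as w13 = [w1; w3] of shape [2 * H_D, D] and w2
# has shape [H_D, D]. Each program handles one BLOCK_M slice of rows and walks over
# H_D in BLOCK_N tiles, keeping the intermediate p tile on-chip (it is written to
# p_ptr only when HAS_P is set). The block-pointer loads use no boundary_check, so
# B_T, D and H_D are assumed to be multiples of BLOCK_M, BLOCK_K and BLOCK_N.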


@triton.autotune(
    configs=[
        triton.Config(
            # B_T, H_D (8192), D (2048)
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [64]
        for BLOCK_N in [128]
        for BLOCK_K in [128, 256]
        for num_stages in [2]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,
    w13_ptr,
    w2_ptr,
    output_ptr,
    p_ptr,
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    HAS_P: tl.constexpr,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    for start_n in range(0, H_D, BLOCK_N):
        if HAS_P:
            P_block_ptr = tl.make_block_ptr(
                base=p_ptr,
                shape=(B_T, H_D),
                strides=(stride_pa, stride_pb),
                offsets=(pid_m * BLOCK_M, start_n),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
        else:
            P_block_ptr = None

        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            # start at the w2 rows that correspond to this H_D tile
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # first GEMM
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # acc_1 = acc_1.to(dtype).to(tl.float32)
        # acc_3 = acc_3.to(dtype).to(tl.float32)
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        if HAS_P:
            tl.store(P_block_ptr, p)
        # second GEMM: accumulate this H_D tile's contribution into the output row
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))


def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
):
    # x: [B_T, D]
    # w13: [H_D*2, D]
    # w2: [H_D, D]
    # output: [B_T, D]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"w13 must have 2 * H_D rows, but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # The kernel accumulates p @ w2 tiles into `output`, so it must start out zeroed.
    output = torch.zeros_like(x)
    if has_p:
        p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
    else:
        p = None

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0) if has_p else 0,
        p.stride(1) if has_p else 0,
        has_p,
        D,
        H_D,
    )

    return output, p
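
# Example usage (illustrative sketch; shapes mirror the benchmark below):
#     x = torch.randn((1024, 2048), dtype=torch.bfloat16, device="cuda")
#     w13 = torch.randn((2 * 8192, 2048), dtype=torch.bfloat16, device="cuda")
#     w2 = torch.randn((8192, 2048), dtype=torch.bfloat16, device="cuda")
#     out, p = fused_ffn(x, w13, w2, has_p=True)  # out: [1024, 2048], p: [1024, 8192]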


@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0).to(tl.int64)

    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out


def eager_ffn(x, w13, w2):
    p = x @ w13.T
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    out = p_out @ w2
    return out, p_out
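
# Note: eager_ffn materializes the full [B_T, 2 * H_D] projection and the [B_T, H_D]
# SiLU-mul result in global memory, whereas fused_ffn_fwd keeps each intermediate tile
# on-chip and only writes it out when has_p=True.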


def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
    eager_out, eager_p = eager_ffn(x, w13, w2)

    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2))
# print("P numeric check: ", torch.allclose(eager_p, ref_p, atol=1e-2, rtol=0)) | ||
# print(triton_p[-1]) | ||
# print(eager_p[-1]) | ||
# print(ref_p[-1]) | ||


def do_benchmark():

    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            line_vals=["eager", "fused"],  # Possible values for `line_arg`
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency (ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # B_T, H_D, D
    numerics_check((64, 128, 128))
    # numerics_check((256, 8192, 2048))
    # do_benchmark()