Prototype (#2486)
Summary: Pull Request resolved: #2486

Differential Revision: D61055780
sijiac authored and facebook-github-bot committed Oct 3, 2024
1 parent a8ce4b5 commit ebb212c
Showing 3 changed files with 427 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/fused_ffn/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
306 changes: 306 additions & 0 deletions torchbenchmark/operators/fused_ffn/kernel.py
@@ -0,0 +1,306 @@

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            # B_T, H_D (8192), D (2048)
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [64]
        for BLOCK_N in [128]
        for BLOCK_K in [128, 256]
        for num_stages in [2]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,
    w13_ptr,
    w2_ptr,
    output_ptr,
    p_ptr,
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    HAS_P: tl.constexpr,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
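    # Fused SwiGLU feed-forward: out = (silu(x @ w1.T) * (x @ w3.T)) @ w2, where
    # w13 stacks w1 and w3 row-wise as a [2 * H_D, D] tensor. Each program owns a
    # BLOCK_M-row tile of x; the intermediate activation p is optionally written
    # out when HAS_P is set.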
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

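    # Tile the hidden dimension: each BLOCK_N-wide chunk of H_D yields a
    # (BLOCK_M, BLOCK_N) slice of the gated activation, which is folded into
    # the output tile via the second GEMM below.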
    for start_n in range(0, H_D, BLOCK_N):
        if HAS_P:
            P_block_ptr = tl.make_block_ptr(
                base=p_ptr,
                shape=(B_T, H_D),
                strides=(stride_pa, stride_pb),
                offsets=(pid_m * BLOCK_M, start_n),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
        else:
            P_block_ptr = None

        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            # start at the rows of w2 that correspond to this H_D tile
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # first GEMM
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # acc_1 = acc_1.to(dtype).to(tl.float32)
        # acc_3 = acc_3.to(dtype).to(tl.float32)
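        # SwiGLU gate: p = silu(x @ w1.T) * (x @ w3.T), computed in fp32 and
        # cast back to the input dtype before the second GEMM.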
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        if HAS_P:
            tl.store(P_block_ptr, p)
        # second GEMM: sweep the D output columns in BLOCK_K chunks and
        # accumulate p @ w2 into the (zero-initialized) output tile.
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))


def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
):
    # x: [B_T, D]
    # w13: [H_D*2, D]
    # w2: [H_D, D]
    # output: [B_T, D]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"H_D_2 must be twice H_D, but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # the kernel accumulates into the output, so it must start zeroed
    output = torch.zeros_like(x)
    if has_p:
        p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
    else:
        p = None

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0) if has_p else 0,
        p.stride(1) if has_p else 0,
        has_p,
        D,
        H_D,
    )

    return output, p
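
# Illustrative usage (shapes mirror the comments above; the sizes are arbitrary):
#   x   = torch.randn(1024, 2048, dtype=torch.bfloat16, device="cuda")
#   w13 = torch.randn(2 * 8192, 2048, dtype=torch.bfloat16, device="cuda")
#   w2  = torch.randn(8192, 2048, dtype=torch.bfloat16, device="cuda")
#   out, p = fused_ffn(x, w13, w2, has_p=False)  # out: [1024, 2048], p is None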


@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
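    # One program per row: compute silu(x1) * x2 over the D columns in
    # BLOCK_SIZE-wide chunks, upcasting to fp32 for the math and storing bf16.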
    b = tl.program_id(0).to(tl.int64)

    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out


def eager_ffn(x, w13, w2):
    p = x @ w13.T
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    out = p_out @ w2
    return out, p_out


def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
    eager_out, eager_p = eager_ffn(x, w13, w2)

    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2))
    # print("P numeric check: ", torch.allclose(eager_p, ref_p, atol=1e-2, rtol=0))
    # print(triton_p[-1])
    # print(eager_p[-1])
    # print(ref_p[-1])


def do_benchmark():

    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            line_vals=["eager", "fused"],
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency(ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # B_T, H_D, D
    numerics_check((64, 128, 128))
    # numerics_check((256, 8192, 2048))
    # do_benchmark()