From ebb212c533e65ec07bd2b461df4471f3f70b3a98 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Wed, 2 Oct 2024 18:11:57 -0700
Subject: [PATCH] Prototype (#2486)

Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2486

Differential Revision: D61055780
---
 .../operators/fused_ffn/__init__.py           |   1 +
 torchbenchmark/operators/fused_ffn/kernel.py  | 306 ++++++++++++++++++
 .../operators/fused_ffn/operator.py           | 120 +++++++
 3 files changed, 427 insertions(+)
 create mode 100644 torchbenchmark/operators/fused_ffn/__init__.py
 create mode 100644 torchbenchmark/operators/fused_ffn/kernel.py
 create mode 100644 torchbenchmark/operators/fused_ffn/operator.py

diff --git a/torchbenchmark/operators/fused_ffn/__init__.py b/torchbenchmark/operators/fused_ffn/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/fused_ffn/kernel.py b/torchbenchmark/operators/fused_ffn/kernel.py
new file mode 100644
index 0000000000..2cb837b867
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/kernel.py
@@ -0,0 +1,306 @@
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            # BLOCK_M tiles B_T, BLOCK_N tiles H_D (8192), BLOCK_K tiles D (2048)
+            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
+            num_stages=num_stages,
+            num_warps=num_warps,
+        )
+        for BLOCK_M in [64]
+        for BLOCK_N in [128]
+        for BLOCK_K in [128, 256]
+        for num_stages in [2]
+        for num_warps in [8]
+    ],
+    key=["B_T", "D", "H_D"],
+)
+@triton.jit
+def fused_ffn_fwd(
+    x_ptr,
+    w13_ptr,
+    w2_ptr,
+    output_ptr,
+    p_ptr,
+    B_T,
+    stride_xa,
+    stride_xb,
+    stride_w13a,
+    stride_w13b,
+    stride_w2a,
+    stride_w2b,
+    stride_oa,
+    stride_ob,
+    stride_pa,
+    stride_pb,
+    HAS_P: tl.constexpr,
+    D: tl.constexpr,
+    H_D: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    dtype = x_ptr.dtype.element_ty
+
+    X_block_ptr = tl.make_block_ptr(
+        base=x_ptr,
+        shape=(B_T, D),
+        strides=(stride_xa, stride_xb),
+        offsets=(pid_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_K),
+        order=(1, 0),
+    )
+    O_block_ptr = tl.make_block_ptr(
+        base=output_ptr,
+        shape=(B_T, D),
+        strides=(stride_oa, stride_ob),
+        offsets=(pid_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_K),
+        order=(1, 0),
+    )
+
+    for start_n in range(0, H_D, BLOCK_N):
+        if HAS_P:
+            P_block_ptr = tl.make_block_ptr(
+                base=p_ptr,
+                shape=(B_T, H_D),
+                strides=(stride_pa, stride_pb),
+                offsets=(pid_m * BLOCK_M, start_n),
+                block_shape=(BLOCK_M, BLOCK_N),
+                order=(1, 0),
+            )
+        else:
+            P_block_ptr = None
+
+        # w13 is [2 * H_D, D]; view it transposed as [D, 2 * H_D] so w1.T and w3.T
+        # are column slices at start_n and H_D + start_n
+        w1t_bptr = tl.make_block_ptr(
+            base=w13_ptr,
+            shape=(D, 2 * H_D),
+            strides=(stride_w13b, stride_w13a),
+            offsets=(0, start_n),
+            block_shape=(BLOCK_K, BLOCK_N),
+            order=(0, 1),
+        )
+        w3t_bptr = tl.make_block_ptr(
+            base=w13_ptr,
+            shape=(D, 2 * H_D),
+            strides=(stride_w13b, stride_w13a),
+            offsets=(0, H_D + start_n),
+            block_shape=(BLOCK_K, BLOCK_N),
+            order=(0, 1),
+        )
+        w2_bptr = tl.make_block_ptr(
+            base=w2_ptr,
+            shape=(H_D, D),
+            strides=(stride_w2a, stride_w2b),
+            offsets=(start_n, 0),
+            block_shape=(BLOCK_N, BLOCK_K),
+            order=(1, 0),
+        )
+
+        x_bptr = X_block_ptr
+        o_bptr = O_block_ptr
+        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+        # first GEMM: acc_1 = x @ w1.T and acc_3 = x @ w3.T for this BLOCK_N slice of H_D
+        w1t_bptr_inner = w1t_bptr
+        w3t_bptr_inner = w3t_bptr
+        w2_bptr_inner = w2_bptr
+        for _ in range(0, D, BLOCK_K):
+            x = tl.load(x_bptr)
+            w1t = tl.load(w1t_bptr_inner)
+            w3t = tl.load(w3t_bptr_inner)
+            acc_1 = tl.dot(x, w1t, acc_1)
+            acc_3 = tl.dot(x, w3t, acc_3)
+            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
+            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
+            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
+        # acc_1 = acc_1.to(dtype).to(tl.float32)
+        # acc_3 = acc_3.to(dtype).to(tl.float32)
+        p = acc_1 * tl.sigmoid(acc_1) * acc_3
+        p = p.to(dtype)
+        if HAS_P:
+            tl.store(P_block_ptr, p)
+        # second GEMM: accumulate p @ w2 into the output row tile, one BLOCK_K slice of D at a time
+        for _ in range(0, D, BLOCK_K):
+            w2 = tl.load(w2_bptr_inner)
+            o = tl.load(o_bptr)
+            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
+            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
+            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))
+
+
+def fused_ffn(
+    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
+):
+    # x: [B_T, D]
+    # w13: [H_D*2, D]
+    # w2: [H_D, D]
+    # output: [B_T, D]
+    B_T, D = x.shape
+    H_D_2, D = w13.shape
+    H_D = w2.shape[0]
+    assert H_D_2 == 2 * H_D, f"H_D_2 must be 2 times H_D but got {H_D_2=} and {H_D=}"
+
+    def grid(META):
+        return (triton.cdiv(B_T, META["BLOCK_M"]),)
+
+    # the kernel accumulates into the output, so it must start from zeros
+    output = torch.zeros_like(x)
+    if has_p:
+        p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
+    else:
+        p = None
+
+    fused_ffn_fwd[grid](
+        x,
+        w13,
+        w2,
+        output,
+        p,
+        B_T,
+        x.stride(0),
+        x.stride(1),
+        w13.stride(0),
+        w13.stride(1),
+        w2.stride(0),
+        w2.stride(1),
+        output.stride(0),
+        output.stride(1),
+        p.stride(0) if has_p else 0,
+        p.stride(1) if has_p else 0,
+        has_p,
+        D,
+        H_D,
+    )
+
+    return output, p
+
+
+@triton.jit
+# pyre-fixme[3]: Return type must be annotated.
+def _silu_mul_kernel(
+    # pyre-fixme[2]: Parameter must be annotated.
+    x1_ptr,
+    x1_stride: tl.constexpr,
+    # pyre-fixme[2]: Parameter must be annotated.
+    x2_ptr,
+    x2_stride: tl.constexpr,
+    # pyre-fixme[2]: Parameter must be annotated.
+    y_ptr,
+    D: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    b = tl.program_id(0).to(tl.int64)
+
+    x1_start = x1_ptr + b * x1_stride
+    x2_start = x2_ptr + b * x2_stride
+    y_start = y_ptr + b * D
+
+    for offset in range(0, D, BLOCK_SIZE):
+        cols = offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < D
+        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
+        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
+        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
+        tl.store(y_start + cols, yv, mask=mask)
+
+
+def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+    assert x1.shape == x2.shape
+    (B_T, D) = x1.shape
+    out = torch.empty_like(x1)
+    assert x1.stride(1) == x2.stride(1) == 1
+    assert out.is_contiguous()
+    grid = (B_T,)
+    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
+    return out
+
+
+def eager_ffn(x, w13, w2):
+    p = x @ w13.T
+    H_D_2, D = w13.shape
+    H_D = H_D_2 // 2
+    p1 = p[:, :H_D]  # B_T, H_D
+    p2 = p[:, H_D:]  # B_T, H_D
+    p_out = silu_mul(p1, p2)  # B_T, H_D
+    out = p_out @ w2
+    return out, p_out
+
+
+def numerics_check(shape):
+    B_T, H_D, D = shape
+    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
+    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
+    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
+    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
+    eager_out, eager_p = eager_ffn(x, w13, w2)
+
+    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2))
+    print(
+        "output numeric check: ",
+        torch.allclose(triton_out, eager_out, atol=1e-2, rtol=1e-2),
+    )
+
+
+def do_benchmark():
+    configs = []
+    configs.append(
+        triton.testing.Benchmark(
+            x_names=[
+                "B_T",
+                "H_D",
+                "D",
+            ],  # Argument names to use as an x-axis for the plot
+            x_vals=[
+                (i, H_D, D)
+                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
+                for i in [1024, 2048, 4096, 8192, 16384]
+            ],  # Different possible values for `x_name`
+            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
+            # Possible values for `line_arg`
+            line_vals=["eager", "fused"],
+            line_names=["Eager", "Fused"],
+            styles=[("green", "-"), ("blue", "-")],
+            ylabel="Latency (ms)",  # Label name for the y-axis
+            plot_name="fused_ffn-benchmark",
+            args={},
+        )
+    )
+
+    @triton.testing.perf_report(configs)
+    def benchmark(B_T, H_D, D, provider):
+        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
+        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
+        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
+        quantiles = [0.5, 0.2, 0.8]
+        if provider == "eager":
+            return triton.testing.do_bench(
+                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
+            )
+        if provider == "fused":
+            return triton.testing.do_bench(
+                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
+            )
+
+    benchmark.run(show_plots=True, print_data=True)
+
+
+if __name__ == "__main__":
+    # B_T, H_D, D
+    numerics_check((64, 128, 128))
+    # numerics_check((256, 8192, 2048))
+    # do_benchmark()
diff --git a/torchbenchmark/operators/fused_ffn/operator.py b/torchbenchmark/operators/fused_ffn/operator.py
new file mode 100644
index 0000000000..bf4fc538d8
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/operator.py
@@ -0,0 +1,120 @@
+import argparse
+import os
+from typing import Any, Callable, Generator, List, Optional, Tuple
+
+import torch
+import triton
+
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+    register_x_val,
+)
+
+from .kernel import eager_ffn, fused_ffn
+
+
+def parse_args(args: List[str]) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="TorchBench fused FFN operator benchmark"
+    )
+    parser.add_argument("--b-t", type=int)
+    parser.add_argument("--h-d", type=int)
+    parser.add_argument("--d", type=int)
+    args = parser.parse_args(args)
+    return args
+
+
+BUILTIN_SHAPES = [
+    (b_t, h_d, d)
+    for h_d, d in [(128, 256), (1024, 512), (8192, 2048)]
+    for b_t in [1024, 2048, 4096, 8192, 16384]
+]
+
+
+class Operator(BenchmarkOperator):
+    DEFAULT_METRICS = ["latency"]
+    DEFAULT_PRECISION = "bf16"
+
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        ffn_args = parse_args(self.extra_args)
+        if ffn_args.b_t and ffn_args.h_d and ffn_args.d:
+            self.shapes = [(ffn_args.b_t, ffn_args.h_d, ffn_args.d)]
+        else:
+            self.shapes = BUILTIN_SHAPES
+
+    @register_benchmark()
+    def fused_ffn_op(self, x, w13, w2) -> Callable:
+        return lambda: fused_ffn(x, w13, w2)
+
+    @register_benchmark()
+    def eager_ffn_op(self, x, w13, w2) -> Callable:
+        return lambda: eager_ffn(x, w13, w2)
+
+    @register_metric()
+    def tflops(
+        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+    ) -> float:
+        x, w13, w2 = example_inputs
+        B_T, _ = x.size()
+        H_D_2, _ = w13.size()
+        H_D, D = w2.size()
+        # gemm #1: x @ w13.T
+        flops = 2 * B_T * H_D_2 * D
+        # gemm #2: p @ w2
+        flops += 2 * B_T * H_D * D
+        # latency is in ms, so convert to TFLOP/s
+        return flops / metrics.latency / 1e12 * 1e3
+
+    @register_x_val(label="(B_T, Hidden_D, D)")
+    def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
+        x, w13, w2 = example_inputs
+        B_T, D = x.size()
+        H_D, D = w2.size()
+        return (B_T, H_D, D)
+
+    def get_input_iter(self) -> Generator:
+        for shape in self.shapes:
+            b_t, h_d, d = shape
+            x = torch.randn((b_t, d), device=self.device, dtype=self.dtype)
+            w13 = torch.randn((2 * h_d, d), device=self.device, dtype=self.dtype)
+            w2 = torch.randn((h_d, d), device=self.device, dtype=self.dtype)
+
+            yield x, w13, w2
+
+    def plot(self):
+        @triton.testing.perf_report(
+            triton.testing.Benchmark(
+                x_names=["shape"],  # argument names to use as an x-axis for the plot
+                x_vals=self.output.x_vals,  # different possible values for `x_name`
+                line_arg="provider",  # argument name whose value corresponds to a different line in the plot
+                line_vals=[
+                    "fused_ffn_op",
+                    "eager_ffn_op",
+                ],  # possible values for `line_arg`; must match the registered benchmark names
+                line_names=[
+                    "Fused FFN",
+                    "Eager FFN",
+                ],  # label name for the lines
+                styles=[
+                    ("blue", "-"),
+                    ("green", "-"),
+                ],  # line styles
+                ylabel="tflops",  # label name for the y-axis
+                plot_name="fused-ffn-performance",  # name for the plot; also used as the file name when saving
+                args={},  # values for function arguments not in `x_names` and `y_name`
+            )
+        )
+        def _plot(shape, provider):
+            tflops = self.output.get_y_vals(shape, provider, "tflops")
+            return tflops
+
+        save_path = self.get_temp_path()
+        os.makedirs(save_path, exist_ok=True)
+        _plot.run(show_plots=True, print_data=True, save_path=save_path)
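
Quick sanity check (not part of the patch): a minimal sketch, assuming the patch is applied and a bf16-capable CUDA device is available, that exercises the fused kernel against the eager reference. The shape mirrors the small case used by numerics_check in kernel.py; the variable names below are illustrative only.

# Illustrative smoke test; not included in the diff above.
import torch

from torchbenchmark.operators.fused_ffn.kernel import eager_ffn, fused_ffn

B_T, H_D, D = 64, 128, 128  # same small shape as numerics_check

x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")  # w1 and w3 stacked
w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")

fused_out, fused_p = fused_ffn(x, w13, w2, has_p=True)
eager_out, eager_p = eager_ffn(x, w13, w2)

print("intermediate allclose:", torch.allclose(fused_p, eager_p, atol=1e-2, rtol=1e-2))
print("output allclose:      ", torch.allclose(fused_out, eager_out, atol=1e-2, rtol=1e-2))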