From dde8528bf62a497b09eddf2d02bda398d2d25467 Mon Sep 17 00:00:00 2001
From: FindHao
Date: Fri, 4 Oct 2024 16:08:24 -0700
Subject: [PATCH] Add FusedLinearCrossEntropy (#2485)

Summary:
As discussed in https://github.com/pytorch/pytorch/issues/136168, I'm going
to migrate operator benchmark implementations. This PR adds several
implementations of FusedLinearCrossEntropy as a starting example.

Execution command:
```
python run_benchmark.py triton --op FusedLinearCrossEntropy
```

Example output:
```
  x_val    LMHeadCE-latency    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-latency
-------  ------------------  -----------------------  ---------------------------------------------
      0             98.0041                   389.87                                         95.0412
      1            196.12                     652.619                                       193.219
      2            417.242                   1248.75                                        416.725
      3            824.906                   2356.25                                        809.56
```
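The Liger variant requires Liger-kernel to be installed first; this PR adds a
`--liger` flag to the Triton userbenchmark installer for that. A typical
session might look like the following (how `install.py` is invoked depends on
the local checkout; the module-path invocation below is an assumption):
```
# Install Liger-kernel without its pinned triton dependency (new --liger flag)
python -m userbenchmark.triton.install --liger
# Benchmark the three FusedLinearCrossEntropy implementations
python run_benchmark.py triton --op FusedLinearCrossEntropy
```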
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2485

Reviewed By: xuzhao9

Differential Revision: D63859871

Pulled By: FindHao

fbshipit-source-id: 4b73a2144702c1f8f3ae5ed15e76112d03f12b87
---
 pyproject.toml                                |  12 ++
 .../FusedLinearCrossEntropy/__init__.py       |   1 +
 .../FusedLinearCrossEntropy/operator.py       | 108 ++++++++++++++++++
 userbenchmark/triton/install.py               |  10 ++
 4 files changed, 131 insertions(+)
 create mode 100644 pyproject.toml
 create mode 100644 torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
 create mode 100644 torchbenchmark/operators/FusedLinearCrossEntropy/operator.py

diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..da571fcfd0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,12 @@
+[build-system]
+# Use legacy backend to import local packages in setup.py
+build-backend = "setuptools.build_meta:__legacy__"
+
+
+[tool.black]
+line-length = 88
+target-version = ["py38"]
+exclude = '''/submodules/.*'''
+
+[tool.usort]
+excludes = ["**/submodules/**"]
diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
new file mode 100644
index 0000000000..9b5ed35541
--- /dev/null
+++ b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
@@ -0,0 +1,108 @@
+import argparse
+from typing import Callable, Generator, List, Optional
+
+import torch
+
+from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark
+
+try:
+    from liger_kernel.transformers.fused_linear_cross_entropy import (
+        LigerFusedLinearCrossEntropyLoss,
+    )
+except ModuleNotFoundError:
+    LigerFusedLinearCrossEntropyLoss = None
+
+# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
+# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
+
+
+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
+    parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
+    return parser.parse_args(args)
+
+
+class TorchLMHeadCE(torch.nn.Module):
+    """Ground-truth implementation: a linear projection followed by the standard torch cross entropy loss.
+
+    :param H: hidden size
+    :param V: vocab size
+    :param ignore_index: index to ignore
+    :param reduction: reduction method
+    """
+
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.ce_loss = torch.nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction="mean"
+        )
+
+    def forward(self, input, target):
+        logits = self.lin(input)
+        return self.ce_loss(logits, target)
+
+
+class LigerLMHeadCE(torch.nn.Module):
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
+            ignore_index=ignore_index, reduction="mean"
+        )
+
+    def forward(self, input, target):
+        return self.ce_loss(self.lin.weight, input, target)
+
+
+class Operator(BenchmarkOperator):
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        op_args = parse_op_args(self.extra_args)
+        self.hidden_size = op_args.hidden_size
+        self.vocab_size = op_args.vocab_size
+        self.baseline_model = TorchLMHeadCE(
+            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
+        ).to(self.device)
+        self.liger_model = LigerLMHeadCE(
+            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
+        ).to(self.device)
+        self.use_cuda_graphs = False
+
+    def get_input_iter(self) -> Generator:
+        for BT in [2**i for i in range(12, 16)]:
+            _input = torch.randn(
+                BT,
+                self.hidden_size,
+                requires_grad=True,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            target = torch.randint(
+                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
+            ).squeeze(1)
+            yield _input, target
+
+    @register_benchmark(baseline=True)
+    def LMHeadCE(self, input, target) -> Callable:
+        return lambda: self.baseline_model(input, target)
+
+    @register_benchmark()
+    def LigerLMHeadCE(self, input, target) -> Callable:
+        return lambda: self.liger_model(input, target)
+
+    @register_benchmark()
+    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
+        compiled = torch.compile(self.baseline_model, dynamic=False)
+        return lambda: compiled(input, target)
+
+    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
+        y = fwd_fn()
+        return lambda: y.backward(retain_graph=True)
diff --git a/userbenchmark/triton/install.py b/userbenchmark/triton/install.py
index a762c402ef..0f5e0ea82d 100644
--- a/userbenchmark/triton/install.py
+++ b/userbenchmark/triton/install.py
@@ -66,6 +66,13 @@ def install_fa3():
     subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))
 
 
+def install_liger():
+    # Liger-kernel pins a `triton` dependency that conflicts with pytorch's,
+    # so we install it without its dependencies.
+    cmd = ["pip", "install", "liger-kernel", "--no-deps"]
+    subprocess.check_call(cmd)
+
+
 def install_tk():
     try:
         from .tk.install import install_tk
@@ -88,6 +95,7 @@ def install_tk():
     )
     parser.add_argument("--jax", action="store_true", help="Install jax nightly")
    parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
+    parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
     parser.add_argument("--test", action="store_true", help="Run test")
 
     args = parser.parse_args()
@@ -105,3 +113,5 @@
         install_jax()
     if args.tk and not args.test:
         install_tk()
+    if args.liger and not args.test:
+        install_liger()
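Note on what the three variants measure: `LMHeadCE` materializes the full
`[BT, V]` logits tensor before the loss, while `LigerLMHeadCE` hands the hidden
states and the lm_head weight directly to the fused kernel
(`self.ce_loss(self.lin.weight, input, target)`), so those logits never need to
exist all at once. A rough eager-mode sketch of that memory-saving idea,
chunked over rows; `chunked_linear_ce` and its chunk size are illustrative and
not part of this PR or the Liger API:
```python
import torch
import torch.nn.functional as F


def chunked_linear_ce(weight, x, target, chunk=1024, ignore_index=-100):
    # Project and reduce one chunk of rows at a time so the full
    # [BT, V] logits tensor is never alive in memory at once.
    total = x.new_zeros(())
    count = 0
    for i in range(0, x.shape[0], chunk):
        logits = x[i : i + chunk] @ weight.T  # [chunk, V], freed each iteration
        t = target[i : i + chunk]
        mask = t != ignore_index
        if mask.any():
            total = total + F.cross_entropy(logits[mask], t[mask], reduction="sum")
            count += int(mask.sum())
    # "mean" reduction over the non-ignored targets, matching the baseline
    return total / max(count, 1)
```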