Summary: As discussed in pytorch/pytorch#136168, I'm going to migrate implementations of operator benchmarking. This PR adds different implementations for FusedLinearCrossEntropy as a starting example.

Execution command:

```
python run_benchmark.py triton --op FusedLinearCrossEntropy
```

Example output:

```
  x_val    LMHeadCE-latency    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-latency
-------  ------------------  -----------------------  ---------------------------------------------
      0             98.0041                  389.87                                          95.0412
      1            196.12                    652.619                                         193.219
      2            417.242                  1248.75                                          416.725
      3            824.906                  2356.25                                          809.56
```

Pull Request resolved: #2485
Reviewed By: xuzhao9
Differential Revision: D63859871
Pulled By: FindHao
fbshipit-source-id: 4b73a2144702c1f8f3ae5ed15e76112d03f12b87
1 parent a1f4b2e · commit dde8528
Showing 4 changed files with 131 additions and 0 deletions.
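Besides the runner-level arguments, the operator parses two flags of its own in operator.py below: --hidden-size (default 4096) and --vocab-size (default 128256). Assuming run_benchmark.py forwards unrecognized flags through to the operator's extra_args, as the plumbing in Operator.__init__ suggests, an invocation with explicit sizes would look like:

```
python run_benchmark.py triton --op FusedLinearCrossEntropy --hidden-size 4096 --vocab-size 128256
```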
pyproject.toml (12 additions, 0 deletions)

```
@@ -0,0 +1,12 @@
[build-system]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"


[tool.black]
line-length = 88
target-version = ["py38"]
exclude = '''/submodules/.*'''

[tool.usort]
excludes = ["**/submodules/**"]
```
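Both formatters read this configuration from pyproject.toml automatically (black from [tool.black], usort from [tool.usort]), so no extra flags are needed to honor the submodule excludes. A sketch, assuming both tools are installed:

```
black .
usort format .
```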
torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py (1 addition, 0 deletions)

```
@@ -0,0 +1 @@
from .operator import Operator
```
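This one-line `__init__.py` turns the FusedLinearCrossEntropy directory into an importable package and re-exports Operator; presumably this is how the runner resolves the --op FusedLinearCrossEntropy name to the implementation below.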
torchbenchmark/operators/FusedLinearCrossEntropy/operator.py (108 additions, 0 deletions)
```
@@ -0,0 +1,108 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.fused_linear_cross_entropy import (
        LigerFusedLinearCrossEntropyLoss,
    )
except ModuleNotFoundError:
    LigerFusedLinearCrossEntropyLoss = None

# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


def parse_op_args(args: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
    parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
    return parser.parse_args(args)


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear layer fused with torch-based cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param dtype: parameter dtype
    :param ignore_index: target index to ignore (reduction is fixed to "mean")
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        logits = self.lin(input)
        return self.ce_loss(logits, target)


class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        return self.ce_loss(self.lin.weight, input, target)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        op_args = parse_op_args(self.extra_args)
        self.hidden_size = op_args.hidden_size
        self.vocab_size = op_args.vocab_size
        self.baseline_model = TorchLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.liger_model = LigerLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                self.hidden_size,
                requires_grad=True,
                dtype=self.dtype,
                device=self.device,
            )
            target = torch.randint(
                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
            ).squeeze(1)
            yield _input, target

    @register_benchmark(baseline=True)
    def LMHeadCE(self, input, target) -> Callable:
        return lambda: self.baseline_model(input, target)

    @register_benchmark()
    def LigerLMHeadCE(self, input, target) -> Callable:
        return lambda: self.liger_model(input, target)

    @register_benchmark()
    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
        compiled = torch.compile(self.baseline_model, dynamic=False)
        return lambda: compiled(input, target)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        y = fwd_fn()
        return lambda: y.backward(retain_graph=True)
```
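By construction, the two modules compute the same quantity: TorchLMHeadCE materializes the full (BT, V) logits through self.lin before applying torch.nn.CrossEntropyLoss, while LigerFusedLinearCrossEntropyLoss takes the weight, input, and target directly and fuses the projection into the loss kernel. A minimal correctness-check sketch, assuming a CUDA device, an installed liger_kernel, and small hypothetical sizes (H, V, and BT below are illustrative, not the benchmark defaults):

```
import torch

# Assumes the classes defined above are importable:
from torchbenchmark.operators.FusedLinearCrossEntropy.operator import (
    LigerLMHeadCE,
    TorchLMHeadCE,
)

# Hypothetical smoke-test sizes; the benchmark itself sweeps BT from 2**12 to
# 2**15 with H=4096, V=128256 (presumably the x_val rows 0-3 in the summary).
H, V, BT = 256, 1024, 64
dtype = torch.float32

baseline = TorchLMHeadCE(H=H, V=V, dtype=dtype).to("cuda")
fused = LigerLMHeadCE(H=H, V=V, dtype=dtype).to("cuda")
# Copy the projection weight so both paths see identical parameters.
with torch.no_grad():
    fused.lin.weight.copy_(baseline.lin.weight)

x = torch.randn(BT, H, dtype=dtype, device="cuda")
y = torch.randint(V, (BT,), dtype=torch.long, device="cuda")

# The fused and unfused losses should agree to within kernel-level tolerance.
torch.testing.assert_close(baseline(x, y), fused(x, y), rtol=1e-4, atol=1e-4)
```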