From b227a68a03b53e047ced5c068ea72d0f697e91f4 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 30 Nov 2023 15:46:54 -0800 Subject: [PATCH] Add test script for pt2 batch fusion (#2018) Summary: Add test case and torchinductor option for pt2 batch fusion Reviewed By: ckluk2 Differential Revision: D47609030 --- torchbenchmark/util/backends/torchdynamo.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index 650da74bf8..9f0e12bf7c 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace: action='store_true', help="set to generate unique triton kernel names in Inductor" ) + parser.add_argument( + "--torchinductor_post_grad_batch_fusion", + type=distutils.util.strtobool, + help="Enable BMM Linear Fusion." + ) parser.add_argument( "--dynamo_disable_optimizer_step", type=distutils.util.strtobool, @@ -119,6 +124,8 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar if args.torchdynamo == "inductor": import torch._inductor as torchinductor torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph) + if bool(args.torchinductor_post_grad_batch_fusion): + torchinductor.config.post_grad_fusion_options["batch_linear"] = {} torch._inductor.config.debug = bool(args.dump_triton) # Setup torchinductor.config.triton.mm