From 488a8d15ffa091ff8aefa0841c98d2824b7027ca Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 28 Nov 2023 18:33:00 -0800
Subject: [PATCH] Add test script for pt2 batch fusion (#2018)

Summary:
Add test case and torchinductor option for pt2 batch fusion

Reviewed By: ckluk2

Differential Revision: D47609030
---
 torchbenchmark/util/backends/torchdynamo.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 650da74bf8..b075a95fac 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         action='store_true',
         help="set to generate unique triton kernel names in Inductor"
     )
+    parser.add_argument(
+        "--torchinductor_post_grad_batch_fusion",
+        type=distutils.util.strtobool,
+        help="Enable BMM Linear Fusion."
+    )
     parser.add_argument(
         "--dynamo_disable_optimizer_step",
         type=distutils.util.strtobool,
@@ -119,6 +124,7 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
     if args.torchdynamo == "inductor":
         import torch._inductor as torchinductor
         torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
+        torchinductor.config.post_grad_batch_fusion = bool(args.torchinductor_post_grad_batch_fusion)
         torch._inductor.config.debug = bool(args.dump_triton)
         # Setup torchinductor.config.triton.mm
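
Note (not part of the patch): below is a minimal standalone sketch of how the new flag is parsed and would be forwarded to Inductor, mirroring the argparse/strtobool pattern the diff adds. It assumes distutils.util is importable (Python < 3.12 or provided via setuptools); the final assignment to torchinductor.config is shown only as a comment since it requires a torch build exposing the post_grad_batch_fusion knob the patch targets.

    # sketch_batch_fusion_flag.py -- illustrative only, assumptions noted above
    import argparse
    import distutils.util
    from typing import List


    def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
        parser = argparse.ArgumentParser()
        # Same flag the patch introduces: strtobool accepts "1"/"0", "true"/"false", etc.
        parser.add_argument(
            "--torchinductor_post_grad_batch_fusion",
            type=distutils.util.strtobool,
            help="Enable BMM Linear Fusion.",
        )
        args, _unknown = parser.parse_known_args(dynamo_args)
        return args


    if __name__ == "__main__":
        args = parse_torchdynamo_args(["--torchinductor_post_grad_batch_fusion", "1"])
        # In the benchmark harness this value is forwarded to Inductor roughly as:
        #   import torch._inductor as torchinductor
        #   torchinductor.config.post_grad_batch_fusion = bool(args.torchinductor_post_grad_batch_fusion)
        print(bool(args.torchinductor_post_grad_batch_fusion))  # prints: True

Because no default is set (as in the patch), omitting the flag leaves the attribute as None, and bool(None) evaluates to False, so the fusion stays disabled unless explicitly requested.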