From 12b2a22b723e5ac7e989be65f67451d9d2d6ce6d Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 12 Aug 2024 10:19:02 -0700
Subject: [PATCH] add benchmark function for PT2 compile

Summary:
# context
* provide a benchmark entry point and command line for running torch.compile
  on a baby model containing an EBC (EmbeddingBagCollection)
* currently supported arguments:
```
rank: int = 0,
world_size: int = 2,
num_features: int = 5,
batch_size: int = 10,
```
* run command:
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=32
```
# results
* on a DevGPU machine:
```
rank: 0, world_size: 2, num_features: 16, batch_size: 10, time: 6.65s
rank: 0, world_size: 2, num_features: 32, batch_size: 10, time: 10.99s
rank: 0, world_size: 2, num_features: 64, batch_size: 10, time: 61.55s
rank: 0, world_size: 2, num_features: 128, batch_size: 10, time: 429.14s
```

Differential Revision: D57501708
---
 .../tests/test_pt2_multiprocess.py | 60 +++++++++++++++++--
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py
index 6322a3959..f0d3d147d 100644
--- a/torchrec/distributed/tests/test_pt2_multiprocess.py
+++ b/torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -9,13 +9,14 @@
 
 #!/usr/bin/env python3
 
+import timeit
 import unittest
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
+import click
 import fbgemm_gpu.sparse_ops  # noqa: F401, E402
-
 import torch
 import torchrec
 import torchrec.pt2.checks
@@ -504,16 +505,16 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
 def _test_compile_fake_pg_fn(
     rank: int,
     world_size: int,
+    num_features: int = 5,
+    batch_size: int = 10,
+    num_embeddings: int = 256,
 ) -> None:
     sharding_type = ShardingType.TABLE_WISE.value
     input_type = _InputType.SINGLE_BATCH
     torch_compile_backend = "eager"
     config = _TestConfig()
-    num_embeddings = 256
     # emb_dim must be % 4 == 0 for fbgemm
     emb_dim = 12
-    batch_size = 10
-    num_features: int = 5
     num_float_features: int = 8
     num_weighted_features: int = 1
 
@@ -773,3 +774,54 @@ def test_compile_multiprocess_fake_pg(
             rank=0,
             world_size=2,
         )
+
+
+@click.command()
+@click.option(
+    "--repeat",
+    type=int,
+    default=1,
+    help="repeat times",
+)
+@click.option(
+    "--rank",
+    type=int,
+    default=0,
+    help="rank in the test",
+)
+@click.option(
+    "--world-size",
+    type=int,
+    default=2,
+    help="world_size in the test",
+)
+@click.option(
+    "--num-features",
+    type=int,
+    default=5,
+    help="num_features in the test",
+)
+@click.option(
+    "--batch-size",
+    type=int,
+    default=10,
+    help="batch_size in the test",
+)
+def compile_benchmark(
+    rank: int, world_size: int, num_features: int, batch_size: int, repeat: int
+) -> None:
+    run: str = (
+        f"_test_compile_fake_pg_fn(rank={rank}, world_size={world_size}, "
+        f"num_features={num_features}, batch_size={batch_size})"
+    )
+    print("*" * 20 + " compile_benchmark started " + "*" * 20)
+    t = timeit.timeit(stmt=run, number=repeat, globals=globals())
+    print("*" * 20 + " compile_benchmark completed " + "*" * 20)
+    print(
+        f"rank: {rank}, world_size: {world_size}, "
+        f"num_features: {num_features}, batch_size: {batch_size}, time: {t:.2f}s"
+    )
+
+
+if __name__ == "__main__":
+    compile_benchmark()
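
For readers who want to try the same measurement outside of Buck, below is a minimal single-process sketch of the timeit pattern the patch adds. It is an illustration, not part of the patch: it builds an unsharded CPU EBC through torchrec's public EmbeddingBagCollection / EmbeddingBagConfig / KeyedJaggedTensor APIs and skips the fake process group and DistributedModelParallel wiring of `_test_compile_fake_pg_fn`, so its timings are not comparable to the results table above. The table and feature names, sizes, and `_compile_once` helper are hypothetical.
```
import timeit

import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor


def _compile_once(num_features: int, batch_size: int) -> None:
    # One small table per feature; emb_dim must be % 4 == 0 for fbgemm.
    tables = [
        EmbeddingBagConfig(
            name=f"table_{i}",
            embedding_dim=12,
            num_embeddings=256,
            feature_names=[f"feature_{i}"],
        )
        for i in range(num_features)
    ]
    ebc = EmbeddingBagCollection(tables=tables, device=torch.device("cpu"))
    # Same backend the test uses; any graph breaks fall back to eager.
    compiled = torch.compile(ebc, backend="eager")

    # One id per feature per sample, so sum(lengths) == len(values).
    keys = [f"feature_{i}" for i in range(num_features)]
    values = torch.randint(0, 256, (num_features * batch_size,))
    lengths = torch.ones(num_features * batch_size, dtype=torch.int32)
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=keys, values=values, lengths=lengths
    )
    compiled(kjt)  # first call triggers compilation


if __name__ == "__main__":
    # number=1 because this times cold-start compilation, not steady-state forward.
    t = timeit.timeit(lambda: _compile_once(num_features=5, batch_size=10), number=1)
    print(f"compile time: {t:.2f}s")
```
As in the patch, the interesting cost is the first compiled call; repeated calls with the same input shapes would mostly measure cache hits rather than compilation.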