Skip to content

Commit

Permalink
add benchmark function for PT2 compile
Browse files Browse the repository at this point in the history
Summary:
# context
* provide a benchmark entry point and command-line interface for running torch.compile with a baby model containing EBC
* current supported arguments:
```
    rank: int = 0,
    world_size: int = 2,
    num_features: int = 5,
    batch_size: int = 10,
```
* run command
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=32
```

# results
* on a DevGPU machine
```
rank: 0, world_size: 2, num_features: 16, batch_size: 10, time: 6.65s
rank: 0, world_size: 2, num_features: 32, batch_size: 10, time: 10.99s
rank: 0, world_size: 2, num_features: 64, batch_size: 10, time: 61.55s
rank: 0, world_size: 2, num_features: 128, batch_size: 10, time: 429.14s
```

Differential Revision: D57501708
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 12, 2024
1 parent 927d7db commit 12b2a22
Showing 1 changed file with 56 additions and 4 deletions.
60 changes: 56 additions & 4 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

#!/usr/bin/env python3

import timeit
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import click
import fbgemm_gpu.sparse_ops # noqa: F401, E402

import torch
import torchrec
import torchrec.pt2.checks
Expand Down Expand Up @@ -504,16 +505,16 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
def _test_compile_fake_pg_fn(
rank: int,
world_size: int,
num_features: int = 5,
batch_size: int = 10,
num_embeddings: int = 256,
) -> None:
sharding_type = ShardingType.TABLE_WISE.value
input_type = _InputType.SINGLE_BATCH
torch_compile_backend = "eager"
config = _TestConfig()
num_embeddings = 256
# emb_dim must be % 4 == 0 for fbgemm
emb_dim = 12
batch_size = 10
num_features: int = 5

num_float_features: int = 8
num_weighted_features: int = 1
Expand Down Expand Up @@ -773,3 +774,54 @@ def test_compile_multiprocess_fake_pg(
rank=0,
world_size=2,
)


@click.command()
@click.option(
    "--repeat",
    type=int,
    default=1,
    help="repeat times",
)
@click.option(
    "--rank",
    type=int,
    default=0,
    help="rank in the test",
)
@click.option(
    "--world-size",
    type=int,
    default=2,
    help="world_size in the test",
)
@click.option(
    "--num-features",
    type=int,
    default=5,
    help="num_features in the test",
)
@click.option(
    "--batch-size",
    type=int,
    default=10,
    help="batch_size in the test",
)
def compile_benchmark(
    rank: int, world_size: int, num_features: int, batch_size: int, repeat: int
) -> None:
    """Benchmark torch.compile (PT2) on a small EBC model via a fake process group.

    Runs ``_test_compile_fake_pg_fn`` ``repeat`` times under ``timeit`` and
    prints the total elapsed wall time in seconds.

    Args:
        rank: rank passed to the test function.
        world_size: world size passed to the test function.
        num_features: number of embedding features in the baby model.
        batch_size: batch size used by the test function.
        repeat: number of timed repetitions (printed time is the total).
    """

    # Time a real callable instead of a formatted source string executed via
    # ``globals=globals()``: avoids string-eval, survives renames/refactors,
    # and keeps the call site visible to linters and IDEs.
    def _run() -> None:
        _test_compile_fake_pg_fn(
            rank=rank,
            world_size=world_size,
            num_features=num_features,
            batch_size=batch_size,
        )

    print("*" * 20 + " compile_benchmark started " + "*" * 20)
    t = timeit.timeit(stmt=_run, number=repeat)
    print("*" * 20 + " compile_benchmark completed " + "*" * 20)
    print(
        f"rank: {rank}, world_size: {world_size}, "
        f"num_features: {num_features}, batch_size: {batch_size}, time: {t:.2f}s"
    )


if __name__ == "__main__":
    compile_benchmark()

0 comments on commit 12b2a22

Please sign in to comment.