Add layout options to gemm
Summary:
We were only benchmarking `row-major x row-major` gemms (also called
`TT` or `transpose-transpose`, because FORTRAN is column-major), which is not
actually the common case: `nn.Linear` uses column-major layouts for weights,
which makes `TN` much more common.
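
A quick PyTorch illustration of why (a sketch, not part of the commit):
`nn.Linear` stores its weight as `(out_features, in_features)` row-major and
computes `x @ weight.T`, so the right-hand operand of the underlying gemm is
column-major:

import torch

lin = torch.nn.Linear(512, 256, bias=False)
x = torch.randn(32, 512)         # row-major activations: the "t" operand
w_t = lin.weight.T               # (512, 256) view, column-major: the "n" operand
print(x.stride(), w_t.stride())  # (512, 1) (1, 512)
# x @ w_t -- the gemm behind nn.Linear -- is therefore a "tn" gemm.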

Reviewed By: adamomainz

Differential Revision: D63714661

fbshipit-source-id: 735c25c59ddeb6596afd9b19f463af92036a830b
bertmaher authored and facebook-github-bot committed Oct 1, 2024
1 parent d512e67 commit 4445aa2
Showing 3 changed files with 37 additions and 37 deletions.
34 changes: 0 additions & 34 deletions torchbenchmark/operators/gemm/data_io.py

This file was deleted.

38 changes: 37 additions & 1 deletion torchbenchmark/operators/gemm/operator.py
@@ -9,6 +9,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torchbenchmark import REPO_PATH
+
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -19,7 +21,6 @@
     register_x_val,
 )
 
-from .data_io import parse_args, read_shapes_from_csv
 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
 from .persistent_matmul import (
@@ -88,6 +89,35 @@
 ]
 
 
+def parse_args(args: List[str]) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="TorchBench Gemm operator Benchmark")
+    parser.add_argument("--m", type=int)
+    parser.add_argument("--k", type=int)
+    parser.add_argument("--n", type=int)
+    parser.add_argument("--bias", type=int)
+    parser.add_argument("--input", type=str)
+    parser.add_argument("--splitk", action="store_true", default=False)
+    parser.add_argument("--llama", action="store_true", default=False)
+    parser.add_argument("--layout", type=str, default="tn")
+    args = parser.parse_args(args)
+    return args
+
+
+def read_shapes_from_csv(csv_path: str) -> List[List[int]]:
+    input_file_path = os.path.join(
+        REPO_PATH, "torchbenchmark", "operators", "gemm", csv_path
+    )
+    shapes = []
+    with open(input_file_path, "r") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            shape = [
+                int(row.get(f)) if row.get(f) else None for f in ("M", "N", "K", "Bias")
+            ]
+            shapes.append(shape)
+    return shapes
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["speedup", "tflops"]
     DEFAULT_PRECISION = "fp16"
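
For reference, `read_shapes_from_csv` expects a header row with `M`, `N`, `K`,
and `Bias` columns, and an empty `Bias` cell becomes `None`. A hypothetical
shapes file passed via `--input` might look like:

M,N,K,Bias
4096,4096,4096,
1024,2048,512,1024

which would load as `[[4096, 4096, 4096, None], [1024, 2048, 512, 1024]]`.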
@@ -98,6 +128,7 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.use_cuda_graphs = False
         gemm_args = parse_args(self.extra_args)
+        self.layout = gemm_args.layout
         if gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
         elif gemm_args.splitk:
@@ -261,6 +292,11 @@ def get_input_iter(self) -> Generator:
                 w = self._scaled_randn(
                     (k, n), scale=k, device=self.device, dtype=self.dtype
                 )
+                # Convert inputs to column-major if layout is "n" (non-transposed)
+                if self.layout[0] == "n":
+                    a = a.T.contiguous().T
+                if self.layout[1] == "n":
+                    w = w.T.contiguous().T
                 if not bias == None:
                     bias = torch.randn(
                         (bias), device=self.device, dtype=self.dtype
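
Putting the new pieces together: the `--layout` string is read positionally,
character 0 for `a` and character 1 for `w`, and the `.T.contiguous().T`
round-trip keeps the values while rewriting storage to column-major. A
standalone sketch of that behavior (mirrors the diff, not code from it):

import torch

def apply_layout(a, w, layout):
    # "n" converts that operand to column-major storage; values are unchanged.
    if layout[0] == "n":
        a = a.T.contiguous().T
    if layout[1] == "n":
        w = w.T.contiguous().T
    return a, w

a, w = torch.randn(64, 128), torch.randn(128, 32)
for layout in ("tt", "tn", "nt", "nn"):
    a2, w2 = apply_layout(a, w, layout)
    assert torch.equal(a, a2) and torch.equal(w, w2)
    print(layout, a2.stride(), w2.stride())
# tt (128, 1) (32, 1)
# tn (128, 1) (1, 128)
# nt (1, 64) (32, 1)
# nn (1, 64) (1, 128)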
2 changes: 0 additions & 2 deletions torchbenchmark/operators/gemm/triton_matmul.py
@@ -201,8 +201,6 @@ def leaky_relu(x):
 def matmul(a, b, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
-    assert a.is_contiguous(), "Matrix A must be contiguous"
-    assert b.is_contiguous(), "Matrix B must be contiguous"
     M, K = a.shape
     K, N = b.shape
     # Allocates output.
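
Dropping the contiguity asserts is what lets column-major operands through;
the kernel addresses memory via explicit stride arguments (as in the Triton
tutorial matmul this file follows), so storage layout doesn't affect
correctness. A small check of the same idea using plain torch (illustrative,
not from this repo):

import torch

a = torch.randn(64, 32)
b = torch.randn(32, 16).T.contiguous().T  # column-major, so b.is_contiguous() is False
torch.testing.assert_close(a @ b, a @ b.contiguous())  # same product either way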
