From f2932b743a73c7c946ec0bf19c3c6f7d6937fb80 Mon Sep 17 00:00:00 2001
From: Hongtao Yu
Date: Tue, 1 Oct 2024 17:18:22 -0700
Subject: [PATCH] Enable fp8 rowwise on AMDGPU (#2483)

Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2483

Reviewed By: karthik-man

Differential Revision: D63726031

fbshipit-source-id: dc410e503f918d83362fb38005ac4a6db5dc1e68
---
 torchbenchmark/operators/fp8_gemm_rowwise/operator.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
index b90698028..c536c9c00 100644
--- a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
+++ b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
@@ -33,6 +33,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )
 
@@ -52,7 +53,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import scale_fp8_row
 
     HAS_CUBLAS = True
-except ImportError:
+except (ImportError, IOError, AttributeError):
     HAS_CUBLAS = False
 
 
@@ -79,7 +80,8 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     (16384, 8192, 13312),
 ]
 
-E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fn).max
+FP8_DTYPE, _, _, _ = get_fp8_constants()
+E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 
@@ -91,7 +93,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.dtype is torch.float16:
         scale = torch.clamp(scale, max=FP16_MAX_POS)
     xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
-        torch.float8_e4m3fn
+        FP8_DTYPE
     )
     return xq, scale.reciprocal().to(torch.float32)
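
Note (not part of the patch): the change replaces the hard-coded torch.float8_e4m3fn with the dtype reported by fbgemm's get_fp8_constants(), which on ROCm builds typically resolves to an AMD-specific FP8 variant (e.g. torch.float8_e4m3fnuz) rather than e4m3fn. The sketch below illustrates the idea of parameterizing row-wise quantization by the FP8 dtype. The function name and the per-row scale computation are illustrative assumptions, not code from the benchmark; only the FP16 clamp, the clamp-to-FP8-max, and the cast mirror the patched hunks above.

# Illustrative sketch only: row-wise FP8 quantization parameterized by the
# FP8 dtype instead of a hard-coded torch.float8_e4m3fn. The per-row scale
# computation here is a simplified assumption.
from typing import Tuple

import torch

EPS: float = 1e-12
FP16_MAX_POS: float = torch.finfo(torch.float16).max


def fp8_row_quantize_sketch(
    x: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Largest representable value for the chosen FP8 format
    # (differs between e4m3fn and the ROCm e4m3fnuz variant).
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per row so each row's max magnitude maps near fp8_max.
    row_max = torch.max(torch.abs(x), dim=1).values
    scale = fp8_max / torch.clamp(row_max, min=EPS)
    if x.dtype is torch.float16:
        scale = torch.clamp(scale, max=FP16_MAX_POS)
    xq = torch.clamp(x * scale[:, None], min=-fp8_max, max=fp8_max).to(fp8_dtype)
    # Return the quantized rows and the reciprocal scale for dequantization.
    return xq, scale.reciprocal().to(torch.float32)

In the patched operator the dtype is obtained once at module level via FP8_DTYPE, _, _, _ = get_fp8_constants(), so the same benchmark code selects the appropriate FP8 format on both CUDA and ROCm builds of fbgemm_gpu.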