Enable fp8 rowwise on AMDGPU (#2483)
Summary: Pull Request resolved: #2483

Reviewed By: karthik-man

Differential Revision: D63726031

fbshipit-source-id: dc410e503f918d83362fb38005ac4a6db5dc1e68
htyu authored and facebook-github-bot committed Oct 2, 2024
1 parent 4445aa2 commit f2932b7
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torchbenchmark/operators/fp8_gemm_rowwise/operator.py
@@ -33,6 +33,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:

 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )

@@ -52,7 +53,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import scale_fp8_row

     HAS_CUBLAS = True
-except ImportError:
+except (ImportError, IOError, AttributeError):
     HAS_CUBLAS = False

@@ -79,7 +80,8 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     (16384, 8192, 13312),
 ]

-E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fn).max
+FP8_DTYPE, _, _, _ = get_fp8_constants()
+E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
 FP16_MAX_POS: float = torch.finfo(torch.float16).max

@@ -91,7 +93,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.dtype is torch.float16:
         scale = torch.clamp(scale, max=FP16_MAX_POS)
     xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
-        torch.float8_e4m3fn
+        FP8_DTYPE
     )
     return xq, scale.reciprocal().to(torch.float32)
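For context on the dtype change: fbgemm_gpu's get_fp8_constants() reports the fp8 format supported by the current platform. CUDA devices use the OCP torch.float8_e4m3fn type (max ±448), while AMD's MI300-class devices use torch.float8_e4m3fnuz (no negative zero, a different exponent bias, and max ±240), so hard-coding torch.float8_e4m3fn is wrong on ROCm. Below is a minimal Python sketch of the idea; pick_fp8_dtype and row_quantize are hypothetical stand-ins written for this note (the real get_fp8_constants returns a four-tuple, of which only the torch dtype is unpacked above), not the fbgemm implementation.

import torch

def pick_fp8_dtype() -> torch.dtype:
    # Hypothetical stand-in for fbgemm_gpu's get_fp8_constants().
    # ROCm builds of PyTorch set torch.version.hip; AMD GPUs use the
    # "fnuz" e4m3 variant, while CUDA GPUs use the OCP e4m3 format.
    if torch.version.hip is not None:
        return torch.float8_e4m3fnuz  # torch.finfo(...).max == 240.0
    return torch.float8_e4m3fn  # torch.finfo(...).max == 448.0

FP8_DTYPE = pick_fp8_dtype()
E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
EPS: float = 1e-12

def row_quantize(x: torch.Tensor):
    # Row-wise quantization in the style of the patched operator: one
    # scale per row, chosen so the row's absmax maps to the fp8 max.
    scale = E4M3_MAX_POS / torch.clamp(x.abs().amax(dim=1), min=EPS)
    xq = torch.clamp(
        x * scale[:, None], min=-E4M3_MAX_POS, max=E4M3_MAX_POS
    ).to(FP8_DTYPE)
    return xq, scale.reciprocal().to(torch.float32)

x = torch.randn(4, 8)
xq, inv_scale = row_quantize(x)
# Dequantize and inspect the round-trip error (bounded by e4m3 precision).
err = (xq.to(torch.float32) * inv_scale[:, None] - x).abs().max()
print(f"max round-trip error: {err.item():.4f}")

Deriving E4M3_MAX_POS from the platform dtype keeps the clamp bound consistent with the quantization target, which is exactly what the commit does by unpacking FP8_DTYPE from get_fp8_constants() instead of hard-coding torch.float8_e4m3fn.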
