diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py index 4bdecac979afb..21fb00fd64a7e 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py @@ -4,23 +4,13 @@ # -------------------------------------------------------------------------- import sys -import ctypes from dataclasses import dataclass -from itertools import product import kernel_explorer as ke import numpy as np import pytest -from ml_dtypes import float8_e4m3fn, float8_e4m3fnuz, finfo -from utils import ( - dtype_to_suffix, - get_gemm_basic_sizes, - get_gemm_bert_sizes, - get_gemm_bound, - matmul, - transab_to_suffix, - dtype_to_bytes, -) +from ml_dtypes import finfo, float8_e4m3fn, float8_e4m3fnuz +from utils import dtype_to_bytes, dtype_to_suffix, get_gemm_bert_sizes, matmul, transab_to_suffix def create_device_array(a): @@ -45,7 +35,7 @@ def compute_scaling_factor(a: np.ndarray, fp8_max: float, margin: int) -> np.nda def cast_and_scale(a, dtype: str): if dtype == "float16": return a.astype(dtype), 1.0 - elif dtype in ("float8_e4m3fn", "float8_e4m3fnuz"): + elif np.dtype(dtype) in (float8_e4m3fn, float8_e4m3fnuz): t = globals()[dtype] sf = compute_scaling_factor(a, fp8_max=finfo(t).max, margin=4) return (a * sf).astype(t), sf diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py index 296cbe6c11765..cdbae640b05d5 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py @@ -8,7 +8,6 @@ import numpy as np import scipy.special -from ml_dtypes import float8_e4m3fnuz def dtype_to_bytes(dtype):