Update test to allow adding ROCm EP
cloudhan committed Nov 23, 2023
1 parent 5e4f369 commit 910a998
Showing 1 changed file with 59 additions and 26 deletions.
onnxruntime/test/python/onnxruntime_test_float8_gemm8.py (85 changes: 59 additions & 26 deletions)
@@ -26,7 +26,9 @@
 class TestFloat8Gemm8(unittest.TestCase):
     def get_model_gemm(
         self,
-        float_name,
+        a_float_name="FLOAT",
+        b_float_name="FLOAT",
+        c_float_name="FLOAT",
         alpha=1.0,
         beta=0.0,
         transA=0,
@@ -35,8 +37,12 @@ def get_model_gemm(
         dtype=TensorProto.FLOAT,
         activation="NONE",
     ):
-        proto_type = getattr(TensorProto, float_name)
-        use_f8 = proto_type in (TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2)
+        a_proto_type = getattr(TensorProto, a_float_name)
+        b_proto_type = getattr(TensorProto, b_float_name)
+        c_proto_type = getattr(TensorProto, c_float_name)
+
+        f8_set = {TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2}
+        use_f8 = len({a_proto_type, b_proto_type, c_proto_type}.intersection(f8_set)) > 0
 
         a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
         b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
@@ -75,9 +81,9 @@ def get_model_gemm(
         else:
             op_name = "Gemm"
         nodes = [
-            make_node("Cast", ["A"], ["Af"], to=proto_type),
-            make_node("Cast", ["B"], ["Bf"], to=proto_type),
-            make_node("Cast", ["C"], ["Cf"], to=proto_type) if bias else None,
+            make_node("Cast", ["A"], ["Af"], to=a_proto_type),
+            make_node("Cast", ["B"], ["Bf"], to=b_proto_type),
+            make_node("Cast", ["C"], ["Cf"], to=c_proto_type) if bias else None,
             make_node(
                 op_name,
                 node_inputs,
@@ -100,7 +106,17 @@ def get_model_gemm(
         check_model(onnx_model)
         return onnx_model
 
-    def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=True, **kwargs):
+    def common_test_model_gemm(
+        self,
+        a_float_name="FLOAT",
+        b_float_name="FLOAT",
+        c_float_name="FLOAT",
+        mul=0.33,
+        atol=0,
+        rtol=0,
+        square=True,
+        **kwargs,
+    ):
         if square:
             a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16))
             b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16))
@@ -113,19 +129,28 @@ def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=Tr
 
         feeds = {"A": a, "B": b}
 
+        providers = ["CPUExecutionProvider"]
+        if "CUDAExecutionProvider" in available_providers:
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        elif "ROCMExecutionProvider" in available_providers:
+            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
+
         expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b)
         expected *= kwargs.get("alpha", 1.0)
         if kwargs.get("beta", 0) != 0:
             expected += kwargs["beta"] * c
             feeds["C"] = c
 
-        onnx_model = self.get_model_gemm("FLOAT", **kwargs)
+        onnx_model = self.get_model_gemm(**kwargs)
 
-        ref = InferenceSession(
-            onnx_model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
-        )
+        ref = InferenceSession(onnx_model.SerializeToString(), providers=providers)
         y = ref.run(None, feeds)[0]
-        if float_type in ("FLOAT", "FLOAT16"):
+        if (
+            "CUDAExecutionProvider" in providers
+            and a_float_name in ("FLOAT", "FLOAT16")
+            and b_float_name in ("FLOAT", "FLOAT16")
+            and c_float_name in ("FLOAT", "FLOAT16")
+        ):
             try:
                 assert_allclose(expected, y, atol=atol, rtol=rtol)
             except Exception as e:
@@ -151,14 +176,18 @@ def check(f):
                     f"\nkwargs={kwargs}"
                 ) from e
 
-            self.assertEqual(expected.shape, y.shape)
-            self.assertEqual(expected.dtype, y.dtype)
+        self.assertEqual(expected.shape, y.shape)
+        self.assertEqual(expected.dtype, y.dtype)
 
-        onnx_model_f8 = self.get_model_gemm(float_type, domain="com.microsoft", **kwargs)
+        onnx_model_f8 = self.get_model_gemm(
+            a_float_name=a_float_name,
+            b_float_name=b_float_name,
+            c_float_name=c_float_name,
+            domain="com.microsoft",
+            **kwargs,
+        )
         try:
-            ref8 = InferenceSession(
-                onnx_model_f8.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
-            )
+            ref8 = InferenceSession(onnx_model_f8.SerializeToString(), providers=providers)
         except Exception as e:
             if "CUDA < 12.0 does not support bias" in str(e):
                 return
@@ -200,28 +229,30 @@ def check(f):
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_model_gemm_float(self):
-        self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3)
+        self.common_test_model_gemm(transA=1, rtol=1e-3)
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_model_gemm_float_default_values(self):
-        self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None)
+        self.common_test_model_gemm(transA=1, rtol=1e-3, activation=None)
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_model_gemm_float_relu(self):
-        self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU")
+        self.common_test_model_gemm(transA=1, rtol=1e-3, activation="RELU")
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_model_gemm_float_gelu(self):
-        self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU")
+        self.common_test_model_gemm(transA=1, rtol=1e-3, activation="GELU")
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
    def test_model_gemm_float_bias(self):
-        self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3)
+        self.common_test_model_gemm(transA=1, beta=1.0, rtol=1e-3)
 
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_model_gemm_float16(self):
         self.common_test_model_gemm(
-            "FLOAT16",
+            a_float_name="FLOAT16",
+            b_float_name="FLOAT16",
+            c_float_name="FLOAT16",
             rtol=1e-2,
             dtype=TensorProto.FLOAT16,
             transB=1,
@@ -231,7 +262,9 @@ def test_model_gemm_float16(self):
     @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
     def test_model_gemm_float8_e4m3(self):
         self.common_test_model_gemm(
-            "FLOAT8E4M3FN",
+            a_float_name="FLOAT8E4M3FN",
+            b_float_name="FLOAT8E4M3FN",
+            c_float_name="FLOAT8E4M3FN",
             rtol=0.5,
             dtype=TensorProto.FLOAT,
             transA=0,
@@ -242,7 +275,7 @@ def test_model_gemm_float8_e4m3(self):
     @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1])))
     @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
     def test_combinations_square_matrices(self, transA, transB):
-        self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3)
+        self.common_test_model_gemm(transA=transA, transB=transB, rtol=1e-3)
 
     @parameterized.parameterized.expand(
         [
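The refactoring above lets each of A, B, and C carry its own float type; the builder leaves the node as a standard Gemm only when none of the three is an 8-bit float type. A standalone sketch of that detection logic (the helper name uses_float8 is ours, not the commit's; requires onnx>=1.14 for the FLOAT8 enums):

    from onnx import TensorProto

    def uses_float8(*float_names):
        # Mirrors the test's check: any of the three inputs being an
        # 8-bit float type routes the model to the com.microsoft domain.
        f8_set = {TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2}
        proto_types = {getattr(TensorProto, name) for name in float_names}
        return bool(proto_types & f8_set)

    assert uses_float8("FLOAT8E4M3FN", "FLOAT", "FLOAT")
    assert not uses_float8("FLOAT", "FLOAT16", "FLOAT16")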

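The provider-selection change is what actually enables the ROCm EP: the test now builds a priority-ordered provider list instead of hard-coding CUDA. A minimal sketch of the same pattern outside the test (the path "model.onnx" is a placeholder, not from this commit):

    import onnxruntime
    from onnxruntime import InferenceSession

    available_providers = onnxruntime.get_available_providers()

    # Prefer CUDA, then ROCm; CPU is always present as the fallback.
    providers = ["CPUExecutionProvider"]
    if "CUDAExecutionProvider" in available_providers:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    elif "ROCMExecutionProvider" in available_providers:
        providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]

    session = InferenceSession("model.onnx", providers=providers)
    print(session.get_providers())  # the providers actually enabled

Because ONNX Runtime treats the providers argument as a priority order, the same test body runs unmodified on CUDA, ROCm, or CPU-only builds.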