From 07b375876ee996758238141a46881f9c883770f5 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 28 Nov 2023 05:29:12 +0000 Subject: [PATCH] Minor --- .../contrib_ops/rocm/math/gemm_float8.cu | 25 ++++++++++++------- .../contrib_ops/rocm/math/gemm_float8_ck.cuh | 8 +++--- .../math/gemm_float8_ck_impl/add_instance.cu | 18 +++++-------- ...xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu | 6 ++--- ...k_f16_f8_f16_mk_kn_mn_instance_original.cu | 12 +++------ ...xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu | 6 ++--- ...k_f8_f16_f16_mk_kn_mn_instance_original.cu | 6 ++--- 7 files changed, 36 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu index fb9344f3fd780..d7d98dfa15756 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -135,13 +135,16 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con params.scale_c = 1.0f; // NOTE: not implemented params.scale_c_dev = nullptr; // NOTE: not implemented - // NOTE: transA is not implemented - if (transB_) { - ORT_NOT_IMPLEMENTED("transB is not implemented"); - // return (*GetOp())(¶ms); - } else { + if (!transA_ && !transB_) { return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transB is not implemented"); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); } template @@ -181,12 +184,16 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con params.scale_c = 1.0f; // NOTE: not implemented params.scale_c_dev = nullptr; // NOTE: not implemented - // NOTE: transA is not implemented - if (transB_) { - return (*GetOp())(¶ms); - } else { + if (!transA_ && !transB_) { return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); } #define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() #else diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh index 8536b5d0433a6..010962f77b27c 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -162,19 +162,19 @@ using Nop = ck::tensor_operation::element_wise::PassThrough; void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( std::vector, Nop, Nop>>>& instances); + Row, Row, Row, F8, F16, F16, Scale, Nop, Nop>>>& instances); void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( std::vector, Nop, Nop>>>& instances); + Row, Row, Row, F8, F16, F16, Scale, Nop, Nop>>>& instances); void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, Nop>>>& instances); + Row, Row, Row, F16, F8, F16, Nop, Scale, Nop>>>& instances); void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, Nop>>>& instances); + Row, Row, Row, F16, F8, F16, Nop, Scale, Nop>>>& instances); void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); } void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); } @@ -99,26 +97,22 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( namespace internal { void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( std::vector, PassThrough>>>& - instances); + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances); void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( std::vector, PassThrough>>>& - instances); + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances); } // namespace internal void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( std::vector, PassThrough>>>& - instances) { + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); } void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( std::vector, PassThrough>>>& - instances) { + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); } diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu index 2bb4f06bd67f4..49463e58886f8 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu @@ -72,8 +72,7 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); ck::tensor_operation::device::instance::add_device_operation_instances( @@ -82,8 +81,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); ck::tensor_operation::device::instance::add_device_operation_instances( diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu index 5a1d98c51e8ff..236e5555051fc 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu @@ -59,20 +59,16 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); } void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( std::vector, PassThrough>>>& - instances) { + Row, Row, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); } } // namespace internal diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu index 63e27c34f3c91..1a0d45df82a71 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu @@ -69,8 +69,7 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( std::vector, PassThrough>>>& - instances) { + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); ck::tensor_operation::device::instance::add_device_operation_instances( @@ -79,8 +78,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( std::vector, PassThrough>>>& - instances) { + Row, Col, Row, F16, F8, F16, PassThrough, Scale, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); ck::tensor_operation::device::instance::add_device_operation_instances( diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu index 9fc6c12722d35..a0628802ec09e 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu @@ -72,8 +72,7 @@ using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( std::vector, PassThrough, PassThrough>>>& - instances) { + Row, Row, Row, F8, F16, F16, Scale, PassThrough, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); ck::tensor_operation::device::instance::add_device_operation_instances( @@ -82,8 +81,7 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( std::vector, PassThrough, PassThrough>>>& - instances) { + Row, Row, Row, F8, F16, F16, Scale, PassThrough, PassThrough>>>& instances) { ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); ck::tensor_operation::device::instance::add_device_operation_instances(