From b6019117feaa71a93ff7678504cbe7cc9e4b67d3 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 22 Nov 2023 11:12:52 +0000 Subject: [PATCH] Add transb support for fp8 b --- .../contrib_ops/rocm/math/gemm_float8.cu | 76 ++++++++++--------- .../contrib_ops/rocm/math/gemm_float8_ck.cuh | 34 ++++++--- .../math/gemm_float8_ck_impl/add_instance.cu | 46 ++++++++--- ...xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu | 9 +-- ...k_f16_f8_f16_mk_kn_mn_instance_original.cu | 9 +-- ...xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu | 76 +++++++++++++++++++ ...k_f8_f16_f16_mk_kn_mn_instance_original.cu | 9 +-- .../kernels/gemm_float8_test.py | 16 ++-- .../kernels/rocm/gemm_float8_ck.cu | 4 +- 9 files changed, 197 insertions(+), 82 deletions(-) create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu index 016006d9d27ff..9468948d4235e 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -21,11 +21,6 @@ class GemmFloat8 final : public RocmKernel { dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); alpha_ = info.GetAttrOrDefault("alpha", 1); beta_ = info.GetAttrOrDefault("beta", 0); - - tunable_op_fp8e4m3fn_fp16_fp16_ = std::make_unique(); - tunable_op_fp8e4m3fnuz_fp16_fp16_ = std::make_unique(); - tunable_op_fp16_fp8e4m3fn_fp16_ = std::make_unique(); - tunable_op_fp16_fp8e4m3fnuz_fp16_ = std::make_unique(); } Status ComputeInternal(OpKernelContext* ctx) const override; @@ -35,21 +30,25 @@ class GemmFloat8 final : public RocmKernel { template Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; - template - [[nodiscard]] inline auto& GetOp() const { - if constexpr (std::is_same_v) { - if constexpr (IsAFp8) { - return tunable_op_fp8e4m3fn_fp16_fp16_; - } else { - return tunable_op_fp16_fp8e4m3fn_fp16_; - } - } else if constexpr (std::is_same_v) { - if constexpr (IsAFp8) { - return tunable_op_fp8e4m3fnuz_fp16_fp16_; - } else { - return tunable_op_fp16_fp8e4m3fnuz_fp16_; - } + template + inline auto MaybeCreateTypeErasedSharedPtr() const { + + } + + template + [[nodiscard]] inline auto* GetOp() const { + using OpT = F8GemmTunableOp; + if (tunable_op_) { + return static_cast(tunable_op_.get()); } + + auto create = std::make_unique(); // avoid new + tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { + auto release = std::unique_ptr(); // avoid delete + release.reset(static_cast(ptr)); + }); + + return static_cast(tunable_op_.get()); } float alpha_; @@ -58,10 +57,8 @@ class GemmFloat8 final : public RocmKernel { bool transB_; int64_t dtype_; - std::unique_ptr> tunable_op_fp8e4m3fn_fp16_fp16_; - std::unique_ptr> tunable_op_fp8e4m3fnuz_fp16_fp16_; - std::unique_ptr> tunable_op_fp16_fp8e4m3fn_fp16_; - std::unique_ptr> tunable_op_fp16_fp8e4m3fnuz_fp16_; + // fully type erased + mutable std::shared_ptr tunable_op_; }; Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { @@ -81,7 +78,7 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { output_shape[output_shape.size() - 1] = b_shape[b_shape.NumDimensions() - 1]; Tensor* Y = ctx->Output(0, output_shape); - ORT_ENFORCE(!transA_ && !transB_, "ROCm GemmFloat8 does not support input transpose"); + ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); @@ -114,20 +111,20 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con params.tuning_ctx = GetTuningContext(); params.stream = ctx->GetComputeStream(); params.handle = GetRocblasHandle(ctx); - params.opa = tunable::blas::BlasOp::NonTrans; - params.opb = tunable::blas::BlasOp::NonTrans; + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; params.m = m; params.n = n; params.k = k; params.a = static_cast(A->DataRaw()); - params.lda = k; + params.lda = transA_ ? m : k; params.scale_a = alpha_; params.scale_a_dev = static_cast(scale_a->DataRaw()); params.b = static_cast(B->DataRaw()); - params.ldb = n; + params.ldb = transB_ ? k : n; params.scale_b = 1.0f; // NOTE: not used params.scale_b_dev = nullptr; // NOTE: not used @@ -136,7 +133,13 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con params.scale_c = 1.0f; // NOTE: not implemented params.scale_c_dev = nullptr; // NOTE: not implemented - return (*GetOp())(¶ms); + // NOTE: transA is not implemented + if (transB_) { + ORT_NOT_IMPLEMENTED("transB is not implemented"); + // return (*GetOp())(¶ms); + } else { + return (*GetOp())(¶ms); + } } template @@ -154,20 +157,20 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con params.tuning_ctx = GetTuningContext(); params.stream = ctx->GetComputeStream(); params.handle = GetRocblasHandle(ctx); - params.opa = tunable::blas::BlasOp::NonTrans; - params.opb = tunable::blas::BlasOp::NonTrans; + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; params.m = m; params.n = n; params.k = k; params.a = static_cast(A->DataRaw()); - params.lda = k; + params.lda = transA_ ? m : k; params.scale_a = 1.0f; // NOTE: not used params.scale_a_dev = nullptr; // NOTE: not used params.b = static_cast(B->DataRaw()); - params.ldb = n; + params.ldb = transB_ ? k : n; params.scale_b = alpha_; params.scale_b_dev = static_cast(scale_b->DataRaw()); @@ -176,7 +179,12 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con params.scale_c = 1.0f; // NOTE: not implemented params.scale_c_dev = nullptr; // NOTE: not implemented - return (*GetOp())(¶ms); + // NOTE: transA is not implemented + if (transB_) { + return (*GetOp())(¶ms); + } else { + return (*GetOp())(¶ms); + } } ONNX_OPERATOR_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh index 63dce1ffd2989..72d50b46cbcd0 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -24,6 +24,13 @@ namespace onnxruntime { namespace rocm { namespace tunable { +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + template constexpr bool always_false = false; @@ -32,7 +39,7 @@ struct Scale { constexpr const static bool is_pack2_invocable = true; constexpr const static bool is_pack4_invocable = true; - explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} + explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} template __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { @@ -141,8 +148,6 @@ struct GemmFloat8Params : tunable::OpParams { int64_t ldc; }; -namespace internal { - #ifdef USE_COMPOSABLE_KERNEL using Row = ck::tensor_layout::gemm::RowMajor; @@ -166,6 +171,14 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, Nop>>>& instances); +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + template auto CreateOp(float scale, const float* dev_scale) { if constexpr (std::is_same_v) { @@ -197,12 +210,17 @@ auto GetCKF8SplitKGemmTypeStringAndOps() { for (auto num_split : {1, 4, 16, 64}) { std::vector> instances{}; // only supports fp8_fp16_fp16_row_row_row and fp16_fp8_fp16_row_row_row now. - if constexpr (std::is_same_v && std::is_same_v && std::is_same_v) { + if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v) { + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); } else { - static_assert(always_false, "no instances for the type combination"); + static_assert(always_false, "no instances for the type combination"); LOGS_DEFAULT(FATAL) << "no instances for the type combination"; } for (auto&& impl : instances) { @@ -230,14 +248,12 @@ auto GetCKF8SplitKGemmTypeStringAndOps() { #endif // USE_COMPOSABLE_KERNEL -} // namespace internal - template class F8GemmTunableOp : public TunableOp> { public: F8GemmTunableOp() { #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : internal::GetCKF8SplitKGemmTypeStringAndOps()) { + for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu index d2622e0337ba4..63c147006f825 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu @@ -13,7 +13,6 @@ namespace onnxruntime { namespace rocm { namespace tunable { namespace blas { -namespace internal { using F8 = ck::f8_t; using F16 = ck::half_t; @@ -27,6 +26,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +namespace internal { void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( std::vector, PassThrough>>>& instances); @@ -42,23 +42,25 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( std::vector, PassThrough>>>& instances); +} // namespace internal void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, PassThrough>>>& instances) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); } void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( std::vector, PassThrough>>>& instances) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); } +namespace internal { void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( std::vector, PassThrough, PassThrough>>>& instances); @@ -76,22 +78,48 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( // void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( // std::vector, PassThrough, PassThrough>>>& instances); +} // namespace internal void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( std::vector, PassThrough, PassThrough>>>& instances) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: } void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( std::vector, PassThrough, PassThrough>>>& instances) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: } +namespace internal { +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& + instances); } // namespace internal + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& + instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& + instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + } // namespace blas } // namespace tunable } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu index d4410c7bb56c9..e336d06cee2f9 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu @@ -19,13 +19,6 @@ namespace tunable { namespace blas { namespace internal { -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - template using S = ck::Sequence; @@ -33,7 +26,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -#define DeviceGemmXdlSplitKCShuffle ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; // The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly template diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu index ddbc154029179..9133e8572005e 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu @@ -19,13 +19,6 @@ namespace tunable { namespace blas { namespace internal { -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - template using S = ck::Sequence; @@ -33,7 +26,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -#define DeviceGemmXdlSplitKCShuffle ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu new file mode 100644 index 0000000000000..f56a3f6ae42f2 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& + instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& + instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu index 65dd0fa5fb8be..c3f419f49a3db 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu @@ -19,13 +19,6 @@ namespace tunable { namespace blas { namespace internal { -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - template using S = ck::Sequence; @@ -33,7 +26,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -#define DeviceGemmXdlSplitKCShuffle ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py index 05c4ca63f01f2..4bdecac979afb 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py @@ -66,8 +66,8 @@ def _test_gemm( np.random.seed(0) - a, scale_a = cast_and_scale(np.random.randn(*a_shape), dta) - b, scale_b = cast_and_scale(np.random.randn(*b_shape), dtb) + a, scale_a = cast_and_scale(np.random.rand(*a_shape), dta) + b, scale_b = cast_and_scale(np.random.rand(*b_shape), dtb) scale_c = float("nan") inv_scale_a = np.array(1 / scale_a).astype("float32") @@ -113,8 +113,8 @@ def _test_gemm( # TODO: how to derive the bound for fp8? atol = 0.01 - rtol = 0.001 - print(f"{dta} {dtb} {dtc} {transab_to_suffix((transa, transb))} m={m} n={n} k={k} atol={atol} rtol={rtol}") + rtol = 0.005 + print(f"atol={atol} rtol={rtol}") # print for pytest -s -v for impl in my_gemm.ListOps(): if not my_gemm.SelectOp(impl): @@ -144,13 +144,15 @@ def _test_gemm( ("float16", "float8_e4m3fn", "float16"), ("float16", "float8_e4m3fnuz", "float16"), ] -all_transabs = [(False, False)] +all_transabs = [(False, False), (False, True)] @pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") @pytest.mark.parametrize( "m, n, k", [ + (1, 768, 768), + (768, 768, 768), (1, 8192, 28672), (1, 28672, 8192), (1, 8192, 8192), @@ -162,6 +164,8 @@ def _test_gemm( @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dta, dtb, dtc", dtypes) def test_ck_gemm(dta, dtb, dtc, transa, transb, m, n, k): + if dtb == "float16" and transb: + pytest.skip("Only supports transb when b is fp8") wrapper_name = f"GemmFloat8CK_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}" _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k) @@ -171,6 +175,8 @@ def test_ck_gemm(dta, dtb, dtc, transa, transb, m, n, k): @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dta, dtb, dtc", dtypes) def test_ck_gemm_alpha_beta(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta): + if dtb == "float16" and transb: + pytest.skip("Only supports transb when b is fp8") wrapper_name = f"GemmFloat8CK_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}" _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8_ck.cu index 214144d324450..a0ad01ddf77b0 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8_ck.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8_ck.cu @@ -16,7 +16,6 @@ #include "python/tools/kernel_explorer/kernel_explorer_interface.h" using namespace onnxruntime::rocm::tunable::blas; -using namespace onnxruntime::rocm::tunable::blas::internal; namespace py = pybind11; @@ -128,6 +127,9 @@ KE_REGISTER(m) { REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NN", half, Float8E4M3FN, half, Row, Row); REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", Float8E4M3FNUZ, half, half, Row, Row); REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", half, Float8E4M3FNUZ, half, Row, Row); + + REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NT", half, Float8E4M3FN, half, Row, Col); + REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", half, Float8E4M3FNUZ, half, Row, Col); } #endif // USE_COMPOSABLE_KERNEL