diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
index 1ef1f5bac16d6..571936fc5f038 100644
--- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
+++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
@@ -195,18 +195,21 @@ auto CreateOp(float scale, const float* dev_scale) {
   }
 }
 
-template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
+template <typename TA, typename TB, typename TC, BlasOp BlasOpA, BlasOp BlasOpB>
 auto GetCKF8SplitKGemmTypeStringAndOps() {
   using CKTA = typename CKDataTypeAdaptor<TA>::type;
   using CKTB = typename CKDataTypeAdaptor<TB>::type;
   using CKTC = typename CKDataTypeAdaptor<TC>::type;
 
+  using CKLayoutA = typename CKBlasOpAdaptor<BlasOpA>::type;
+  using CKLayoutB = typename CKBlasOpAdaptor<BlasOpB>::type;
+
   using OpA = std::conditional_t<std::is_same_v<CKTA, ck::f8_t>, Scale, Nop>;
   using OpB = std::conditional_t<std::is_same_v<CKTB, ck::f8_t>, Scale, Nop>;
   using OpC = std::conditional_t<std::is_same_v<CKTC, ck::f8_t>, Scale, Nop>;
 
   using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK<
-      ALayout, BLayout, Row,
+      CKLayoutA, CKLayoutB, Row,
       CKTA, CKTB, CKTC,
       OpA, OpB, OpC>;
 
@@ -215,16 +218,16 @@ auto GetCKF8SplitKGemmTypeStringAndOps() {
   for (auto num_split : {1, 4, 16, 64}) {
     std::vector<std::unique_ptr<DeviceGemm>> instances{};
     if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t> &&
-                  std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
+                  std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
       add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances);
     } else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
-                         std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
+                         std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
       add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances);
     } else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
-                         std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Col>) {
+                         std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Col>) {
       add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances);
     } else {
-      static_assert(always_false<ALayout>, "no instances for the type combination");
+      static_assert(always_false<CKLayoutA>, "no instances for the type combination");
       LOGS_DEFAULT(FATAL) << "no instances for the type combination";
     }
     for (auto&& impl : instances) {
@@ -257,9 +260,7 @@ class GemmFloat8TunableOp : public TunableOp<GemmFloat8Params<TA, TB, TC>> {
  public:
   GemmFloat8TunableOp() {
 #if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
-    using ALayout = std::conditional_t<OpA == BlasOp::N, Row, Col>;
-    using BLayout = std::conditional_t<OpB == BlasOp::N, Row, Col>;
-    for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
+    for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, OpA, OpB>()) {
       ORT_UNUSED_PARAMETER(_);
       this->RegisterOp(std::move(op));
     }
diff --git a/onnxruntime/core/providers/rocm/composable_kernel_common.h b/onnxruntime/core/providers/rocm/composable_kernel_common.h
index 90f4613a986ae..6f504995e40a3 100644
--- a/onnxruntime/core/providers/rocm/composable_kernel_common.h
+++ b/onnxruntime/core/providers/rocm/composable_kernel_common.h
@@ -5,15 +5,24 @@
 
 #ifdef USE_COMPOSABLE_KERNEL
 #include "ck/utility/data_type.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #endif
 
 #include "core/framework/float8.h"
 #include "core/providers/rocm/rocm_common.h"
+#include "core/providers/rocm/tunable/gemm_common.h"
 
 namespace onnxruntime {
 namespace rocm {
 
 #ifdef USE_COMPOSABLE_KERNEL
+template <tunable::blas::BlasOp Op>
+struct CKBlasOpAdaptor {
+  using type = std::conditional_t<Op == tunable::blas::BlasOp::N,
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::ColumnMajor>;
+};
+
 template <typename T>
 struct CKDataTypeAdaptor {
   using type = T;
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py
index c73e73211b3a7..19a1008b3947a 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py
@@ -174,6 +174,19 @@ def test_ck_gemm_alpha_beta(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta):
     _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta)
 
 
+@pytest.mark.skipif(not ke.is_float8_available(), reason="float8 is not enabled")
+@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled")
+@pytest.mark.parametrize("alpha, beta", [(1.5, 0.0), (2.0, 0.0)])
+@pytest.mark.parametrize("m, n, k", [(256, 256, 256)])
+@pytest.mark.parametrize("transa, transb", all_transabs)
+@pytest.mark.parametrize("dta, dtb, dtc", dtypes)
+def test_tunable_gemm(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta):
+    if dtb == "float16" and transb:
+        pytest.skip("Only supports transb when b is fp8")
+    wrapper_name = f"GemmFloat8Tunable_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}"
+    _test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta)
+
+
 @dataclass
 class GemmMetric(ke.BandwidthMetric, ke.ComputeMetric):
     transa: bool
@@ -258,7 +271,9 @@ def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k, sort):
         profile_gemm_func(
             getattr(ke, "GemmFloat8CK" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k
         )
-        # profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
+        profile_gemm_func(
+            getattr(ke, "GemmFloat8Tunable" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k
+        )
     print()
 
 
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu
index e927200fc50c1..2d78f390af84a 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu
@@ -22,7 +22,7 @@ namespace py = pybind11;
 namespace onnxruntime {
 
 #if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
-template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
+template <typename TA, typename TB, typename TC, BlasOp OpA, BlasOp OpB>
 class GemmFloat8CK : public IKernelExplorer {
  public:
   GemmFloat8CK(BlasOp opa, BlasOp opb,
@@ -31,11 +31,8 @@ class GemmFloat8CK : public IKernelExplorer {
                DeviceArray& a, int64_t lda, DeviceArray& scale_a,
                DeviceArray& b, int64_t ldb, DeviceArray& scale_b,
                float beta,
-               DeviceArray& c, int64_t ldc, DeviceArray& scale_c)
-      : params_{} {
-    auto supports_a = opa == BlasOp::N ? std::is_same_v<ALayout, Row> : std::is_same_v<ALayout, Col>;
-    auto supports_b = opb == BlasOp::N ? std::is_same_v<BLayout, Row> : std::is_same_v<BLayout, Col>;
-    ORT_ENFORCE(supports_a && supports_b);
+               DeviceArray& c, int64_t ldc, DeviceArray& scale_c) {
+    ORT_ENFORCE(opa == OpA && opb == OpB);
 
     params_.tuning_ctx = TuningContext();
     params_.stream = Stream();
@@ -69,7 +66,7 @@ class GemmFloat8CK : public IKernelExplorer {
       params_.scale_c_dev = static_cast<const float*>(scale_c.ptr());
     }
 
-    for (auto&& [type_string, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
+    for (auto&& [type_string, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, OpA, OpB>()) {
       type_strings_.emplace_back(std::move(type_string));
       ops_.emplace_back(std::move(op));
     }
@@ -105,31 +102,106 @@ class GemmFloat8CK : public IKernelExplorer {
   size_t selected_op_{};
 };
 
-#define REGISTER_OP_COMMON(registered_name, tpl, dta, dtb, dtc, alayout, blayout) \
-  py::class_<tpl<dta, dtb, dtc, alayout, blayout>>(m, registered_name)            \
-      .def("SetRepeats", &tpl<dta, dtb, dtc, alayout, blayout>::SetRepeats)       \
-      .def("Profile", &tpl<dta, dtb, dtc, alayout, blayout>::Profile)             \
-      .def("Run", &tpl<dta, dtb, dtc, alayout, blayout>::Run)                     \
-      .def("ListOps", &tpl<dta, dtb, dtc, alayout, blayout>::ListOps)             \
-      .def("SelectOp", &tpl<dta, dtb, dtc, alayout, blayout>::SelectOp)
-
-#define REGISTER_GEMM_FLOAT8_CK(registered_name, dta, dtb, dtc, alayout, blayout)    \
-  REGISTER_OP_COMMON(registered_name, GemmFloat8CK, dta, dtb, dtc, alayout, blayout) \
-      .def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, float, DeviceArray&, int64_t, DeviceArray&, DeviceArray&, int64_t, DeviceArray&, float, DeviceArray&, int64_t, DeviceArray&>());
+template <typename TA, typename TB, typename TC, BlasOp OpA, BlasOp OpB>
+class GemmFloat8Tunable : public IKernelExplorer {
+ public:
+  GemmFloat8Tunable(BlasOp opa, BlasOp opb,
+                    int64_t m, int64_t n, int64_t k,
+                    float alpha,
+                    DeviceArray& a, int64_t lda, DeviceArray& scale_a,
+                    DeviceArray& b, int64_t ldb, DeviceArray& scale_b,
+                    float beta,
+                    DeviceArray& c, int64_t ldc, DeviceArray& scale_c) {
+    ORT_ENFORCE(opa == OpA && opb == OpB);
+
+    params_.tuning_ctx = TuningContext();
+    params_.stream = Stream();
+    // rocblas handle is not used for ck
+    params_.handle = nullptr;
+    params_.opa = opa;
+    params_.opb = opb;
+    params_.m = m;
+    params_.n = n;
+    params_.k = k;
+
+    params_.a = static_cast<TA*>(a.ptr());
+    params_.lda = lda;
+    if constexpr (std::is_same_v<TA, Float8E4M3FN> || std::is_same_v<TA, Float8E4M3FNUZ>) {
+      params_.scale_a = alpha;
+      params_.scale_a_dev = static_cast<const float*>(scale_a.ptr());
+    }
+
+    params_.b = static_cast<TB*>(b.ptr());
+    params_.ldb = ldb;
+    if constexpr (std::is_same_v<TB, Float8E4M3FN> || std::is_same_v<TB, Float8E4M3FNUZ>) {
+      params_.scale_b = alpha;
+      params_.scale_b_dev = static_cast<const float*>(scale_b.ptr());
+    }
+
+    params_.c = static_cast<TC*>(c.ptr());
+    params_.ldc = ldc;
+    if constexpr (std::is_same_v<TC, Float8E4M3FN> || std::is_same_v<TC, Float8E4M3FNUZ>) {
+      ORT_ENFORCE(false, "Not implemented");
+      params_.scale_c = beta;
+      params_.scale_c_dev = static_cast<const float*>(scale_c.ptr());
+    }
+
+    params_.TuningContext()->EnableTunableOpAndTuning();
+  }
+
+  void Run() override {
+    ORT_THROW_IF_ERROR(op_(&params_));
+  }
+
+  std::vector<std::string> ListOps() const {
+    return {"Tunable"};
+  }
+
+  bool SelectOp(const std::string& name) {
+    return name == "Tunable";
+  }
+
+ private:
+  using ParamsT = GemmFloat8Params<TA, TB, TC>;
+  using OpT = GemmFloat8TunableOp<TA, TB, TC, OpA, OpB>;
+  ParamsT params_{};
+  OpT op_;
+};
+
+#define REGISTER_GEMM_FLOAT8(registered_name, tpl, dta, dtb, dtc, opa, opb) \
+  py::class_<tpl<dta, dtb, dtc, opa, opb>>(m, registered_name)              \
+      .def("SetRepeats", &tpl<dta, dtb, dtc, opa, opb>::SetRepeats)         \
+      .def("Profile", &tpl<dta, dtb, dtc, opa, opb>::Profile)               \
+      .def("Run", &tpl<dta, dtb, dtc, opa, opb>::Run)                       \
+      .def("ListOps", &tpl<dta, dtb, dtc, opa, opb>::ListOps)               \
+      .def("SelectOp", &tpl<dta, dtb, dtc, opa, opb>::SelectOp)             \
+      .def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, float, DeviceArray&, int64_t, DeviceArray&, DeviceArray&, int64_t, DeviceArray&, float, DeviceArray&, int64_t, DeviceArray&>());
 
 KE_REGISTER(m) {
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_fp8e4m3fn_half_half_NN", Float8E4M3FN, half, half, Row, Row);
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NN", half, Float8E4M3FN, half, Row, Row);
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", Float8E4M3FNUZ, half, half, Row, Row);
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", half, Float8E4M3FNUZ, half, Row, Row);
+  using BlasOp = rocm::tunable::blas::BlasOp;
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fn_half_half_NN", GemmFloat8CK, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NN", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", GemmFloat8CK, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N);
+
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NT", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T);
+  REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T);
+}
 
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NT", half, Float8E4M3FN, half, Row, Col);
-  REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", half, Float8E4M3FNUZ, half, Row, Col);
+KE_REGISTER(m) {
+  using BlasOp = rocm::tunable::blas::BlasOp;
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fn_half_half_NN", GemmFloat8Tunable, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NN", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fnuz_half_half_NN", GemmFloat8Tunable, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N);
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NN", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N);
+
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NT", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T);
+  REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NT", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T);
+}
 
 #endif
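
For reference, a minimal usage sketch of one of the newly registered tunable wrappers (not part of the patch). It assumes a kernel_explorer build with USE_COMPOSABLE_KERNEL and float8 enabled, the `ke.blas_op` enum and `ke.DeviceArray` helpers used by the existing gemm tests, and ml_dtypes for host-side fp8 data; shapes, scales, and variable names are illustrative:

```python
import ml_dtypes
import numpy as np

import kernel_explorer as ke

m = n = k = 256
alpha, beta = 1.5, 0.0

a = np.random.rand(m, k).astype("float16")                # fp16 operand
b = np.random.rand(k, n).astype(ml_dtypes.float8_e4m3fn)  # fp8 operand
c = np.zeros((m, n), dtype="float16")
scale = np.ones((1,), dtype="float32")                    # dummy dequant scale

dev_a, dev_b, dev_c = ke.DeviceArray(a), ke.DeviceArray(b), ke.DeviceArray(c)
dev_sa, dev_sb, dev_sc = (ke.DeviceArray(scale.copy()) for _ in range(3))

# half x fp8 -> half, both operands non-transposed ("NN");
# leading dimensions are k/n/n for row-major a/b/c.
op = ke.GemmFloat8Tunable_half_fp8e4m3fn_half_NN(
    ke.blas_op.N, ke.blas_op.N, m, n, k, alpha,
    dev_a, k, dev_sa, dev_b, n, dev_sb, beta, dev_c, n, dev_sc
)
op.Run()  # the ctor enabled tuning, so the first Run selects the best instance
dev_c.UpdateHostNumpyArray()  # copy the result back into c
```

Note that alpha is carried through the fp8 side's dequant scale (`scale_b` here, since b is the fp8 operand), and beta is only consulted for an fp8 output, which the wrapper currently rejects as not implemented.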