Skip to content

Commit

Permalink
Add tunable binding
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 5, 2023
1 parent 5b6deb5 commit c564bae
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 38 deletions.
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,21 @@ auto CreateOp(float scale, const float* dev_scale) {
}
}

template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
template <typename TA, typename TB, typename TC, BlasOp LayoutOpA, BlasOp LayoutOpB>
auto GetCKF8SplitKGemmTypeStringAndOps() {
using CKTA = typename CKDataTypeAdaptor<TA>::type;
using CKTB = typename CKDataTypeAdaptor<TB>::type;
using CKTC = typename CKDataTypeAdaptor<TC>::type;

using CKLayoutA = typename CKBlasOpAdaptor<LayoutOpA>::type;
using CKLayoutB = typename CKBlasOpAdaptor<LayoutOpB>::type;

using OpA = std::conditional_t<std::is_same_v<CKTA, ck::f8_t>, Scale<TA>, Nop>;
using OpB = std::conditional_t<std::is_same_v<CKTB, ck::f8_t>, Scale<TB>, Nop>;
using OpC = std::conditional_t<std::is_same_v<CKTC, ck::f8_t>, Scale<TC>, Nop>;

using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK<
ALayout, BLayout, Row,
CKLayoutA, CKLayoutB, Row,
CKTA, CKTB, CKTC,
OpA, OpB, OpC>;

Expand All @@ -215,16 +218,16 @@ auto GetCKF8SplitKGemmTypeStringAndOps() {
for (auto num_split : {1, 4, 16, 64}) {
std::vector<std::unique_ptr<DeviceGemm>> instances{};

Check warning on line 219 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L219

Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:219:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t> &&

Check warning on line 220 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L220

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:220:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances);
} else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&

Check warning on line 223 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L223

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:223:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

Check warning on line 223 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L223

If an else has a brace on one side, it should have it on both [readability/braces] [5]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:223:  If an else has a brace on one side, it should have it on both  [readability/braces] [5]
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Row>) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances);
} else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&

Check warning on line 226 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L226

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:226:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

Check warning on line 226 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L226

If an else has a brace on one side, it should have it on both [readability/braces] [5]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:226:  If an else has a brace on one side, it should have it on both  [readability/braces] [5]
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Col>) {
std::is_same_v<CKLayoutA, Row> && std::is_same_v<CKLayoutB, Col>) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances);
} else {
static_assert(always_false<CKTA, CKTB, CKTC, ALayout, BLayout>, "no instances for the type combination");
static_assert(always_false<CKTA, CKTB, CKTC, CKLayoutA, CKLayoutB>, "no instances for the type combination");
LOGS_DEFAULT(FATAL) << "no instances for the type combination";
}
for (auto&& impl : instances) {
Expand Down Expand Up @@ -257,9 +260,7 @@ class GemmFloat8TunableOp : public TunableOp<GemmFloat8Params<TA, TB, TC>> {
public:
GemmFloat8TunableOp() {
#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
using ALayout = std::conditional_t<OpA == BlasOp::NonTrans, Row, Col>;
using BLayout = std::conditional_t<OpB == BlasOp::NonTrans, Row, Col>;
for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, OpA, OpB>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/rocm/composable_kernel_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,24 @@

#ifdef USE_COMPOSABLE_KERNEL
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#endif

#include "core/framework/float8.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/gemm_common.h"

namespace onnxruntime {
namespace rocm {

#ifdef USE_COMPOSABLE_KERNEL
template <tunable::blas::BlasOp Op>
struct CKBlasOpAdaptor {
using type = std::conditional_t<Op == tunable::blas::BlasOp::NonTrans,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor>;
};

template <typename T>
struct CKDataTypeAdaptor {
using type = T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,19 @@ def test_ck_gemm_alpha_beta(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta)
_test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta)


@pytest.mark.skipif(not ke.is_float8_available(), reason="float8 is not enabled")
@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled")
@pytest.mark.parametrize("alpha, beta", [(1.5, 0.0), [2.0, 0.0]])
@pytest.mark.parametrize("m, n, k", [(256, 256, 256)])
@pytest.mark.parametrize("transa, transb", all_transabs)
@pytest.mark.parametrize("dta, dtb, dtc", dtypes)
def test_tunable_gemm(dta, dtb, dtc, transa, transb, m, n, k, alpha, beta):
if dtb == "float16" and transb:
pytest.skip("Only supports transb when b is fp8")
wrapper_name = f"GemmFloat8Tunable_{dtype_to_suffix(dta)}_{dtype_to_suffix(dtb)}_{dtype_to_suffix(dtc)}_{transab_to_suffix((transa, transb))}"
_test_gemm(getattr(ke, wrapper_name), dta, dtb, dtc, transa, transb, m, n, k, alpha, beta)


@dataclass
class GemmMetric(ke.BandwidthMetric, ke.ComputeMetric):
transa: bool
Expand Down Expand Up @@ -258,7 +271,9 @@ def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k, sort):
profile_gemm_func(
getattr(ke, "GemmFloat8CK" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k
)
# profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k)
profile_gemm_func(
getattr(ke, "GemmFloat8Tunable" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k
)
print()


Expand Down
128 changes: 100 additions & 28 deletions onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace py = pybind11;
namespace onnxruntime {

#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES)
template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
template <typename TA, typename TB, typename TC, BlasOp OpA, BlasOp OpB>
class GemmFloat8CK : public IKernelExplorer {
public:
GemmFloat8CK(BlasOp opa, BlasOp opb,
Expand All @@ -31,11 +31,8 @@ class GemmFloat8CK : public IKernelExplorer {
DeviceArray& a, int64_t lda, DeviceArray& scale_a,
DeviceArray& b, int64_t ldb, DeviceArray& scale_b,
float beta,
DeviceArray& c, int64_t ldc, DeviceArray& scale_c)
: params_{} {
auto supports_a = opa == BlasOp::N ? std::is_same_v<ALayout, Row> : std::is_same_v<ALayout, Col>;
auto supports_b = opb == BlasOp::N ? std::is_same_v<BLayout, Row> : std::is_same_v<BLayout, Col>;
ORT_ENFORCE(supports_a && supports_b);
DeviceArray& c, int64_t ldc, DeviceArray& scale_c) {
ORT_ENFORCE(opa == OpA && opb == OpB);

params_.tuning_ctx = TuningContext();
params_.stream = Stream();
Expand Down Expand Up @@ -69,7 +66,7 @@ class GemmFloat8CK : public IKernelExplorer {
params_.scale_c_dev = static_cast<float*>(scale_c.ptr());
}

for (auto&& [type_string, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
for (auto&& [type_string, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, OpA, OpB>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
Expand Down Expand Up @@ -105,31 +102,106 @@ class GemmFloat8CK : public IKernelExplorer {
size_t selected_op_{};
};

#define REGISTER_OP_COMMON(registered_name, tpl, dta, dtb, dtc, alayout, blayout) \
py::class_<tpl<dta, dtb, dtc, alayout, blayout>>(m, registered_name) \
.def("SetRepeats", &tpl<dta, dtb, dtc, alayout, blayout>::SetRepeats) \
.def("Profile", &tpl<dta, dtb, dtc, alayout, blayout>::Profile) \
.def("Run", &tpl<dta, dtb, dtc, alayout, blayout>::Run) \
.def("ListOps", &tpl<dta, dtb, dtc, alayout, blayout>::ListOps) \
.def("SelectOp", &tpl<dta, dtb, dtc, alayout, blayout>::SelectOp)

#define REGISTER_GEMM_FLOAT8_CK(registered_name, dta, dtb, dtc, alayout, blayout) \
REGISTER_OP_COMMON(registered_name, GemmFloat8CK, dta, dtb, dtc, alayout, blayout) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
float, \
DeviceArray&, int64_t, DeviceArray&, \
DeviceArray&, int64_t, DeviceArray&, \
float, \
template <typename TA, typename TB, typename TC, BlasOp OpA, BlasOp OpB>
class GemmFloat8Tunable : public IKernelExplorer {
public:
GemmFloat8Tunable(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
float alpha,
DeviceArray& a, int64_t lda, DeviceArray& scale_a,
DeviceArray& b, int64_t ldb, DeviceArray& scale_b,
float beta,
DeviceArray& c, int64_t ldc, DeviceArray& scale_c) {
ORT_ENFORCE(opa == OpA && opb == OpB);

params_.tuning_ctx = TuningContext();
params_.stream = Stream();
// rocblas handle is not used for ck
params_.handle = nullptr;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;

params_.a = static_cast<TA*>(a.ptr());
params_.lda = lda;
if constexpr (std::is_same_v<TA, Float8E4M3FN> || std::is_same_v<TA, Float8E4M3FNUZ>) {
params_.scale_a = alpha;
params_.scale_a_dev = static_cast<float*>(scale_a.ptr());
}

params_.b = static_cast<TB*>(b.ptr());
params_.ldb = ldb;
if constexpr (std::is_same_v<TB, Float8E4M3FN> || std::is_same_v<TB, Float8E4M3FNUZ>) {
params_.scale_b = alpha;
params_.scale_b_dev = static_cast<float*>(scale_b.ptr());
}

params_.c = static_cast<TC*>(c.ptr());
params_.ldc = ldc;
if constexpr (std::is_same_v<TC, Float8E4M3FN> || std::is_same_v<TC, Float8E4M3FNUZ>) {
ORT_ENFORCE(false, "Not implemented");
params_.scale_c = beta;
params_.scale_c_dev = static_cast<float*>(scale_c.ptr());
}

params_.TuningContext()->EnableTunableOpAndTuning();
}

void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
}

std::vector<std::string> ListOps() const {
return {"Tunable"};
}

bool SelectOp(const std::string& name) {
return name == "Tunable";
}

private:
using ParamsT = GemmFloat8Params<TA, TB, TC>;
using OpT = GemmFloat8TunableOp<TA, TB, TC, OpA, OpB>;
ParamsT params_{};
OpT op_;
};

#define REGISTER_GEMM_FLOAT8(registered_name, tpl, dta, dtb, dtc, opa, opb) \
py::class_<tpl<dta, dtb, dtc, opa, opb>>(m, registered_name) \
.def("SetRepeats", &tpl<dta, dtb, dtc, opa, opb>::SetRepeats) \
.def("Profile", &tpl<dta, dtb, dtc, opa, opb>::Profile) \
.def("Run", &tpl<dta, dtb, dtc, opa, opb>::Run) \
.def("ListOps", &tpl<dta, dtb, dtc, opa, opb>::ListOps) \
.def("SelectOp", &tpl<dta, dtb, dtc, opa, opb>::SelectOp) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
float, \
DeviceArray&, int64_t, DeviceArray&, \
DeviceArray&, int64_t, DeviceArray&, \
float, \
DeviceArray&, int64_t, DeviceArray&>());

KE_REGISTER(m) {
REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_fp8e4m3fn_half_half_NN", Float8E4M3FN, half, half, Row, Row);
REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NN", half, Float8E4M3FN, half, Row, Row);
REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", Float8E4M3FNUZ, half, half, Row, Row);
REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", half, Float8E4M3FNUZ, half, Row, Row);
using BlasOp = rocm::tunable::blas::BlasOp;
REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fn_half_half_NN", GemmFloat8CK, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N);

Check warning on line 187 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L187

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:187:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NN", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N);

Check warning on line 188 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L188

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:188:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8CK_fp8e4m3fnuz_half_half_NN", GemmFloat8CK, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N);

Check warning on line 189 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L189

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:189:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NN", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N);

Check warning on line 190 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L190

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:190:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fn_half_NT", GemmFloat8CK, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T);

Check warning on line 192 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L192

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:192:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", GemmFloat8CK, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T);

Check warning on line 193 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L193

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:193:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}

REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fn_half_NT", half, Float8E4M3FN, half, Row, Col);
REGISTER_GEMM_FLOAT8_CK("GemmFloat8CK_half_fp8e4m3fnuz_half_NT", half, Float8E4M3FNUZ, half, Row, Col);
KE_REGISTER(m) {
using BlasOp = rocm::tunable::blas::BlasOp;
REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fn_half_half_NN", GemmFloat8Tunable, Float8E4M3FN, half, half, BlasOp::N, BlasOp::N);

Check warning on line 198 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L198

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:198:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NN", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::N);

Check warning on line 199 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L199

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:199:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_fp8e4m3fnuz_half_half_NN", GemmFloat8Tunable, Float8E4M3FNUZ, half, half, BlasOp::N, BlasOp::N);

Check warning on line 200 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L200

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:200:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NN", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::N);

Check warning on line 201 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L201

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:201:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fn_half_NT", GemmFloat8Tunable, half, Float8E4M3FN, half, BlasOp::N, BlasOp::T);

Check warning on line 203 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L203

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:203:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
REGISTER_GEMM_FLOAT8("GemmFloat8Tunable_half_fp8e4m3fnuz_half_NT", GemmFloat8Tunable, half, Float8E4M3FNUZ, half, BlasOp::N, BlasOp::T);

Check warning on line 204 in onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu#L204

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_float8.cu:204:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
#endif

Expand Down

0 comments on commit c564bae

Please sign in to comment.