Skip to content

Commit

Permalink
Minor
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Nov 28, 2023
1 parent 32e976c commit 084325e
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 46 deletions.
25 changes: 16 additions & 9 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,16 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con
params.scale_c = 1.0f; // NOTE: not implemented
params.scale_c_dev = nullptr; // NOTE: not implemented

// NOTE: transA is not implemented
if (transB_) {
ORT_NOT_IMPLEMENTED("transB is not implemented");
// return (*GetOp<Fp8T, MLFloat16, MLFloat16, BlasOp::NonTrans, BlasOp::Trans>())(&params);
} else {
if (!transA_ && !transB_) {
return (*GetOp<Fp8T, MLFloat16, MLFloat16, BlasOp::NonTrans, BlasOp::NonTrans>())(&params);
} else if (transA_ && !transB_) {
ORT_NOT_IMPLEMENTED("transA is not implemented");
} else if (!transA_ && transB_) {
ORT_NOT_IMPLEMENTED("transB is not implemented");
} else if (transA_ && transB_) {
ORT_NOT_IMPLEMENTED("transA & transB is not implemented");
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable");
}

template <typename Fp8T>
Expand Down Expand Up @@ -181,12 +184,16 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con
params.scale_c = 1.0f; // NOTE: not implemented
params.scale_c_dev = nullptr; // NOTE: not implemented

// NOTE: transA is not implemented
if (transB_) {
return (*GetOp<MLFloat16, Fp8T, MLFloat16, BlasOp::NonTrans, BlasOp::Trans>())(&params);
} else {
if (!transA_ && !transB_) {
return (*GetOp<MLFloat16, Fp8T, MLFloat16, BlasOp::NonTrans, BlasOp::NonTrans>())(&params);
} else if (transA_ && !transB_) {
ORT_NOT_IMPLEMENTED("transA is not implemented");
} else if (!transA_ && transB_) {
return (*GetOp<MLFloat16, Fp8T, MLFloat16, BlasOp::NonTrans, BlasOp::Trans>())(&params);
} else if (transA_ && transB_) {
ORT_NOT_IMPLEMENTED("transA & transB is not implemented");
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable");
}
#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints<MLFloat16, Float8E4M3FN, Float8E4M3FNUZ>()
#else
Expand Down
9 changes: 4 additions & 5 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,19 @@ using Nop = ck::tensor_operation::element_wise::PassThrough;

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, ck::f8_t, ck::half_t, ck::half_t, Scale<Float8E4M3FN>, Nop, Nop>>>& instances);
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, Nop, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, ck::f8_t, ck::half_t, ck::half_t, Scale<Float8E4M3FNUZ>, Nop, Nop>>>& instances);
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, Nop, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, ck::half_t, ck::f8_t, ck::half_t, Nop, Scale<Float8E4M3FN>, Nop>>>& instances);
Row, Row, Row, F16, F8, F16, Nop, Scale<Float8E4M3FN>, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, ck::half_t, ck::f8_t, ck::half_t, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);
Row, Row, Row, F16, F8, F16, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Expand Down Expand Up @@ -214,7 +214,6 @@ auto GetCKF8SplitKGemmTypeStringAndOps() {

for (auto num_split : {1, 4, 16, 64}) {
std::vector<std::unique_ptr<DeviceGemm>> instances{};

Check warning on line 216 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L216

Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:216:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
// only supports fp8_fp16_fp16_row_row_row and fp16_fp8_fp16_row_row_row now.
if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t> &&

Check warning on line 217 in onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh#L217

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh:217:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,14 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
}

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
}
Expand Down Expand Up @@ -99,26 +97,22 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
namespace internal {
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances);
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances);
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances);
} // namespace internal

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
}

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple<

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort<Float8E4M3FN>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand All @@ -82,8 +81,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort<Float8E4M3FNUZ>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,16 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple<

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck<Float8E4M3FN>{});
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck<Float8E4M3FN>{});
}

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck<Float8E4M3FNUZ>{});
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck<Float8E4M3FNUZ>{});
}

} // namespace internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple<

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances<Float8E4M3FN>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand All @@ -79,8 +78,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances<Float8E4M3FNUZ>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple<

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>&
instances) {
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck<Float8E4M3FN>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand All @@ -82,8 +81,7 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>&
instances) {
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances) {
ck::tensor_operation::device::instance::add_device_operation_instances(
instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck<Float8E4M3FNUZ>{});
ck::tensor_operation::device::instance::add_device_operation_instances(
Expand Down

0 comments on commit 084325e

Please sign in to comment.