Add transb support for fp8 b
cloudhan committed Nov 22, 2023
1 parent 38317d2 commit b601911
Showing 9 changed files with 197 additions and 82 deletions.
76 changes: 42 additions & 34 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
@@ -21,11 +21,6 @@ class GemmFloat8 final : public RocmKernel {
dtype_ = info.GetAttrOrDefault<int64_t>("dtype", onnx::TensorProto_DataType_FLOAT16);
alpha_ = info.GetAttrOrDefault<float>("alpha", 1);
beta_ = info.GetAttrOrDefault<float>("beta", 0);

tunable_op_fp8e4m3fn_fp16_fp16_ = std::make_unique<decltype(tunable_op_fp8e4m3fn_fp16_fp16_)::element_type>();
tunable_op_fp8e4m3fnuz_fp16_fp16_ = std::make_unique<decltype(tunable_op_fp8e4m3fnuz_fp16_fp16_)::element_type>();
tunable_op_fp16_fp8e4m3fn_fp16_ = std::make_unique<decltype(tunable_op_fp16_fp8e4m3fn_fp16_)::element_type>();
tunable_op_fp16_fp8e4m3fnuz_fp16_ = std::make_unique<decltype(tunable_op_fp16_fp8e4m3fnuz_fp16_)::element_type>();
}
Status ComputeInternal(OpKernelContext* ctx) const override;

@@ -35,21 +30,25 @@ class GemmFloat8 final : public RocmKernel {
template <typename Fp8T>
Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const;

template <typename Fp8T, bool IsAFp8>
[[nodiscard]] inline auto& GetOp() const {
if constexpr (std::is_same_v<Fp8T, Float8E4M3FN>) {
if constexpr (IsAFp8) {
return tunable_op_fp8e4m3fn_fp16_fp16_;
} else {
return tunable_op_fp16_fp8e4m3fn_fp16_;
}
} else if constexpr (std::is_same_v<Fp8T, Float8E4M3FNUZ>) {
if constexpr (IsAFp8) {
return tunable_op_fp8e4m3fnuz_fp16_fp16_;
} else {
return tunable_op_fp16_fp8e4m3fnuz_fp16_;
}
template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
inline auto MaybeCreateTypeErasedSharedPtr() const {

}

template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
[[nodiscard]] inline auto* GetOp() const {
using OpT = F8GemmTunableOp<TA, TB, TC, ALayout, BLayout>;
if (tunable_op_) {
return static_cast<OpT*>(tunable_op_.get());
}

auto create = std::make_unique<OpT>(); // avoid new
tunable_op_ = std::shared_ptr<void>(create.release(), [](void* ptr) {
auto release = std::unique_ptr<OpT>(); // avoid delete
release.reset(static_cast<OpT*>(ptr));
});

return static_cast<OpT*>(tunable_op_.get());
}

float alpha_;
@@ -58,10 +57,8 @@ class GemmFloat8 final : public RocmKernel {
bool transB_;
int64_t dtype_;

std::unique_ptr<F8GemmTunableOp<Float8E4M3FN, MLFloat16, MLFloat16, internal::Row, internal::Row>> tunable_op_fp8e4m3fn_fp16_fp16_;
std::unique_ptr<F8GemmTunableOp<Float8E4M3FNUZ, MLFloat16, MLFloat16, internal::Row, internal::Row>> tunable_op_fp8e4m3fnuz_fp16_fp16_;
std::unique_ptr<F8GemmTunableOp<MLFloat16, Float8E4M3FN, MLFloat16, internal::Row, internal::Row>> tunable_op_fp16_fp8e4m3fn_fp16_;
std::unique_ptr<F8GemmTunableOp<MLFloat16, Float8E4M3FNUZ, MLFloat16, internal::Row, internal::Row>> tunable_op_fp16_fp8e4m3fnuz_fp16_;
// fully type erased
mutable std::shared_ptr<void> tunable_op_;
};
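
For reference, a minimal standalone sketch of the caching pattern GetOp uses above: a single type-erased std::shared_ptr<void> member lazily holds whichever concrete F8GemmTunableOp instantiation is requested, and the deleter restores the static type before destruction so the correct destructor runs. The free-standing helper and its name below are illustrative only, not part of this commit; note the pattern assumes a given kernel instance only ever requests one instantiation, since the cached pointer is cast back unconditionally.

#include <memory>

// Illustrative helper (hypothetical name): lazily create a T, cache it in a
// type-erased shared_ptr<void>, and hand back a typed pointer. The deleter
// re-wraps the raw pointer in a unique_ptr<T> so ~T() is invoked on release.
template <typename T>
T* GetOrCreateCached(std::shared_ptr<void>& cache) {
  if (cache) {
    return static_cast<T*>(cache.get());
  }
  auto created = std::make_unique<T>();
  cache = std::shared_ptr<void>(created.release(), [](void* ptr) {
    std::unique_ptr<T> typed(static_cast<T*>(ptr));  // take ownership back, then ~T()
  });
  return static_cast<T*>(cache.get());
}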

Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
@@ -81,7 +78,7 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
output_shape[output_shape.size() - 1] = b_shape[b_shape.NumDimensions() - 1];
Tensor* Y = ctx->Output(0, output_shape);

ORT_ENFORCE(!transA_ && !transB_, "ROCm GemmFloat8 does not support input transpose");
ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose");
ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16");
ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input");
ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling");
@@ -114,20 +111,20 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con
params.tuning_ctx = GetTuningContext();
params.stream = ctx->GetComputeStream();
params.handle = GetRocblasHandle(ctx);
params.opa = tunable::blas::BlasOp::NonTrans;
params.opb = tunable::blas::BlasOp::NonTrans;
params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans;
params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans;

params.m = m;
params.n = n;
params.k = k;

params.a = static_cast<const Fp8T*>(A->DataRaw());
params.lda = k;
params.lda = transA_ ? m : k;
params.scale_a = alpha_;
params.scale_a_dev = static_cast<const float*>(scale_a->DataRaw());

params.b = static_cast<const MLFloat16*>(B->DataRaw());
params.ldb = n;
params.ldb = transB_ ? k : n;
params.scale_b = 1.0f; // NOTE: not used
params.scale_b_dev = nullptr; // NOTE: not used

@@ -136,7 +133,13 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con
params.scale_c = 1.0f; // NOTE: not implemented
params.scale_c_dev = nullptr; // NOTE: not implemented

return (*GetOp<Fp8T, true>())(&params);
// NOTE: transA is not implemented
if (transB_) {
ORT_NOT_IMPLEMENTED("transB is not implemented");
// return (*GetOp<Fp8T, MLFloat16, MLFloat16, Row, Col>())(&params);
} else {
return (*GetOp<Fp8T, MLFloat16, MLFloat16, Row, Row>())(&params);
}
}
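
As a reference for the lda/ldb choices above (an illustrative sketch, not part of this commit): with row-major storage the leading dimension is the column count of the matrix as it sits in memory, so op(A) of shape m x k uses lda = k when stored as m x k and lda = m when stored transposed as k x m, and op(B) of shape k x n uses ldb = n when stored as k x n and ldb = k when stored transposed as n x k. A tiny reference GEMM making the transB indexing explicit (transA is omitted, mirroring the kernel):

#include <cstdint>

// Naive row-major C = A * op(B); a is m x k with lda == k, and b is either
// k x n (ldb == n) or, when trans_b is true, stored as n x k (ldb == k).
template <typename T>
void NaiveGemmRowMajor(bool trans_b, int64_t m, int64_t n, int64_t k,
                       const T* a, int64_t lda,
                       const T* b, int64_t ldb,
                       T* c, int64_t ldc) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      T acc{};
      for (int64_t p = 0; p < k; ++p) {
        // b(p, j) lives at b[p * ldb + j] untransposed, at b[j * ldb + p] transposed.
        const T b_pj = trans_b ? b[j * ldb + p] : b[p * ldb + j];
        acc += a[i * lda + p] * b_pj;
      }
      c[i * ldc + j] = acc;
    }
  }
}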

template <typename Fp8T>
@@ -154,20 +157,20 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con
params.tuning_ctx = GetTuningContext();
params.stream = ctx->GetComputeStream();
params.handle = GetRocblasHandle(ctx);
params.opa = tunable::blas::BlasOp::NonTrans;
params.opb = tunable::blas::BlasOp::NonTrans;
params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans;
params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans;

params.m = m;
params.n = n;
params.k = k;

params.a = static_cast<const MLFloat16*>(A->DataRaw());
params.lda = k;
params.lda = transA_ ? m : k;
params.scale_a = 1.0f; // NOTE: not used
params.scale_a_dev = nullptr; // NOTE: not used

params.b = static_cast<const Fp8T*>(B->DataRaw());
params.ldb = n;
params.ldb = transB_ ? k : n;
params.scale_b = alpha_;
params.scale_b_dev = static_cast<const float*>(scale_b->DataRaw());

@@ -176,7 +179,12 @@ Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, con
params.scale_c = 1.0f; // NOTE: not implemented
params.scale_c_dev = nullptr; // NOTE: not implemented

return (*GetOp<Fp8T, false>())(&params);
// NOTE: transA is not implemented
if (transB_) {
return (*GetOp<MLFloat16, Fp8T, MLFloat16, Row, Col>())(&params);
} else {
return (*GetOp<MLFloat16, Fp8T, MLFloat16, Row, Row>())(&params);
}
}

ONNX_OPERATOR_KERNEL_EX(
34 changes: 25 additions & 9 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh
@@ -24,6 +24,13 @@ namespace onnxruntime {
namespace rocm {
namespace tunable {

using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <typename... Ts>
constexpr bool always_false = false;

@@ -32,7 +39,7 @@ struct Scale {
constexpr const static bool is_pack2_invocable = true;
constexpr const static bool is_pack4_invocable = true;

explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {}

template <typename Y, typename X>
__forceinline__ __host__ __device__ Y fast_type_convert(X x) const {
@@ -141,8 +148,6 @@ struct GemmFloat8Params : tunable::OpParams {
int64_t ldc;
};

namespace internal {

#ifdef USE_COMPOSABLE_KERNEL

using Row = ck::tensor_layout::gemm::RowMajor;
@@ -166,6 +171,14 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, ck::half_t, ck::f8_t, ck::half_t, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, Nop, Scale<Float8E4M3FN>, Nop>>>& instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, Nop, Scale<Float8E4M3FNUZ>, Nop>>>& instances);

template <typename OrtT>
auto CreateOp(float scale, const float* dev_scale) {
if constexpr (std::is_same_v<OrtT, Float8E4M3FN>) {
@@ -197,12 +210,17 @@ auto GetCKF8SplitKGemmTypeStringAndOps() {
for (auto num_split : {1, 4, 16, 64}) {
std::vector<std::unique_ptr<DeviceGemm>> instances{};
// only supports fp8_fp16_fp16_row_row_row and fp16_fp8_fp16_row_row_row now.
if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t>) {
if constexpr (std::is_same_v<CKTA, ck::f8_t> && std::is_same_v<CKTB, ck::half_t> && std::is_same_v<CKTC, ck::half_t> &&
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances);
} else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t>) {
} else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances);
} else if constexpr (std::is_same_v<CKTA, ck::half_t> && std::is_same_v<CKTB, ck::f8_t> && std::is_same_v<CKTC, ck::half_t> &&
std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Col>) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances);
} else {
static_assert(always_false<CKTA, CKTB, CKTC>, "no instances for the type combination");
static_assert(always_false<CKTA, CKTB, CKTC, ALayout, BLayout>, "no instances for the type combination");
LOGS_DEFAULT(FATAL) << "no instances for the type combination";
}
for (auto&& impl : instances) {
@@ -230,14 +248,12 @@ auto GetCKF8SplitKGemmTypeStringAndOps() {

#endif // USE_COMPOSABLE_KERNEL
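
An aside on the dispatch above (illustrative sketch, not part of this commit): always_false is the usual dependent-false idiom, so the static_assert in the final else branch only fires if an unsupported (type, layout) combination is actually instantiated; supported combinations never compile that branch at all.

#include <type_traits>

struct RowTag {};
struct ColTag {};

// Illustrative twin of always_false: dependent on the template parameters, so
// the static_assert below is only evaluated when its branch is instantiated.
template <typename... Ts>
constexpr bool dependent_false = false;

template <typename Layout>
constexpr const char* LayoutName() {
  if constexpr (std::is_same_v<Layout, RowTag>) {
    return "Row";
  } else if constexpr (std::is_same_v<Layout, ColTag>) {
    return "Col";
  } else {
    static_assert(dependent_false<Layout>, "unsupported layout");
    return nullptr;  // unreachable; keeps compilers quiet about a missing return
  }
}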

} // namespace internal

template <typename TA, typename TB, typename TC, typename ALayout, typename BLayout>
class F8GemmTunableOp : public TunableOp<GemmFloat8Params<TA, TB, TC>> {
public:
F8GemmTunableOp() {
#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : internal::GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps<TA, TB, TC, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
(next changed file)
@@ -13,7 +13,6 @@ namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace internal {

using F8 = ck::f8_t;
using F16 = ck::half_t;
@@ -27,6 +26,7 @@ using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

namespace internal {
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>& instances);
@@ -42,23 +42,25 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>& instances);
} // namespace internal

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
}

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances);
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances);
}

namespace internal {
void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances);
@@ -76,22 +78,48 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(
// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(
// std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
// Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances);
} // namespace internal

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FN>, PassThrough, PassThrough>>>& instances) {
add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
// add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO:
internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
// internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO:
}

void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Row, Row, F8, F16, F16, Scale<Float8E4M3FNUZ>, PassThrough, PassThrough>>>& instances) {
add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
// add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO:
internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances);
// internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO:
}

namespace internal {
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances);

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances);
} // namespace internal

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FN>, PassThrough>>>&
instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
}

void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGemmSplitK<
Row, Col, Row, F16, F8, F16, PassThrough, Scale<Float8E4M3FNUZ>, PassThrough>>>&
instances) {
internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances);
}

} // namespace blas
} // namespace tunable
} // namespace rocm
(next changed file)
@@ -19,21 +19,14 @@ namespace tunable {
namespace blas {
namespace internal {

using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

#define DeviceGemmXdlSplitKCShuffle ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle;

// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly
template <typename ScaleElemT>
(next changed file)
@@ -19,21 +19,14 @@ namespace tunable {
namespace blas {
namespace internal {

using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

#define DeviceGemmXdlSplitKCShuffle ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename ScaleElemT>