
Commit

Fix mnk
cloudhan committed Nov 28, 2023
1 parent 084325e commit 561acdf
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
@@ -27,9 +27,11 @@ class GemmFloat8 final : public RocmKernel {
private:
#if !defined(DISABLE_FLOAT8_TYPES)
template <typename Fp8T>
- Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const;
+ Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+                           const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const;
template <typename Fp8T>
- Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const;
+ Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+                           const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const;

template <typename TA, typename TB, typename TC, BlasOp OpA, BlasOp OpB>
[[nodiscard]] inline auto* GetOp() const {
@@ -38,7 +40,7 @@ class GemmFloat8 final : public RocmKernel {
return static_cast<OpT*>(tunable_op_.get());
}

- auto create = std::make_unique<OpT>(); // avoid new
+ auto create = std::make_unique<OpT>();  // avoid new
tunable_op_ = std::shared_ptr<void>(create.release(), [](void* ptr) {
auto release = std::unique_ptr<OpT>(); // avoid delete
release.reset(static_cast<OpT*>(ptr));
@@ -71,11 +73,15 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {

auto a_shape = A->Shape();
auto b_shape = B->Shape();
- ORT_ENFORCE(a_shape.NumDimensions() >= 2 && b_shape.NumDimensions() == 2);  // is in form of input @ weight
- ORT_ENFORCE(a_shape[a_shape.NumDimensions() - 1] == b_shape[0]);  // k is compatible
+ ORT_ENFORCE(a_shape.NumDimensions() == 2);
+ ORT_ENFORCE(b_shape.NumDimensions() == 2);

- TensorShapeVector output_shape = a_shape.AsShapeVector();
- output_shape[output_shape.size() - 1] = b_shape[b_shape.NumDimensions() - 1];
+ auto m = !transA_ ? a_shape[0] : a_shape[1];
+ auto k = !transA_ ? a_shape[1] : a_shape[0];
+ ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1]));  // k is compatible

+ auto n = !transB_ ? b_shape[1] : b_shape[0];
+
+ TensorShapeVector output_shape = {m, n};
Tensor* Y = ctx->Output(0, output_shape);

ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose");
@@ -84,13 +90,13 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling");

if (A->IsDataType<Float8E4M3FN>()) {
- return ComputeFp8Fp16Fp16<Float8E4M3FN>(ctx, A, scale_a, B, Y);
+ return ComputeFp8Fp16Fp16<Float8E4M3FN>(ctx, m, n, k, A, scale_a, B, Y);
} else if (A->IsDataType<Float8E4M3FNUZ>()) {
- return ComputeFp8Fp16Fp16<Float8E4M3FNUZ>(ctx, A, scale_a, B, Y);
+ return ComputeFp8Fp16Fp16<Float8E4M3FNUZ>(ctx, m, n, k, A, scale_a, B, Y);
} else if (B->IsDataType<Float8E4M3FN>()) {
- return ComputeFp16Fp8Fp16<Float8E4M3FN>(ctx, A, B, scale_b, Y);
+ return ComputeFp16Fp8Fp16<Float8E4M3FN>(ctx, m, n, k, A, B, scale_b, Y);
} else if (B->IsDataType<Float8E4M3FNUZ>()) {
- return ComputeFp16Fp8Fp16<Float8E4M3FNUZ>(ctx, A, B, scale_b, Y);
+ return ComputeFp16Fp8Fp16<Float8E4M3FNUZ>(ctx, m, n, k, A, B, scale_b, Y);
}

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8");
@@ -99,16 +105,11 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {

#if !defined(DISABLE_FLOAT8_TYPES)
template <typename Fp8T>
- Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const {
+ Status GemmFloat8::ComputeFp8Fp16Fp16(
+     OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+     const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const {
ORT_ENFORCE(A->IsDataType<Fp8T>() && scale_a->IsDataType<float>() && B->IsDataType<MLFloat16>());

- auto a_shape = A->Shape();
- auto b_shape = B->Shape();
-
- auto m = a_shape.Slice(0, a_shape.NumDimensions() - 1).Size();
- auto k = a_shape[a_shape.NumDimensions() - 1];
- auto n = b_shape[b_shape.NumDimensions() - 1];
-
onnxruntime::rocm::tunable::blas::GemmFloat8Params<Fp8T, MLFloat16, MLFloat16> params{};
params.tuning_ctx = GetTuningContext();
params.stream = ctx->GetComputeStream();
@@ -148,16 +149,11 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con
}

template <typename Fp8T>
- Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const {
+ Status GemmFloat8::ComputeFp16Fp8Fp16(
+     OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+     const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const {
ORT_ENFORCE(A->IsDataType<MLFloat16>() && B->IsDataType<Fp8T>() && scale_b->IsDataType<float>());

- auto a_shape = A->Shape();
- auto b_shape = B->Shape();
-
- auto m = a_shape.Slice(0, a_shape.NumDimensions() - 1).Size();
- auto k = a_shape[a_shape.NumDimensions() - 1];
- auto n = b_shape[b_shape.NumDimensions() - 1];
-
onnxruntime::rocm::tunable::blas::GemmFloat8Params<MLFloat16, Fp8T, MLFloat16> params{};
params.tuning_ctx = GetTuningContext();
params.stream = ctx->GetComputeStream();
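For readers skimming the diff, the following is a minimal standalone sketch of the M/N/K derivation this patch switches to, assuming 2-D inputs and the Gemm convention Y = op(A) * op(B) with per-input transpose flags. DeriveMNK and its signature are illustrative names only, not onnxruntime code.

// Sketch only (assumed helper, not part of onnxruntime): derive M, N, K for
// Y = op(A) * op(B), where op(X) transposes X when the matching flag is set.
// Both A and B are assumed to be 2-D, mirroring the ORT_ENFORCE checks above.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <tuple>

std::tuple<int64_t, int64_t, int64_t> DeriveMNK(const std::array<int64_t, 2>& a_shape,
                                                const std::array<int64_t, 2>& b_shape,
                                                bool trans_a, bool trans_b) {
  const int64_t m = !trans_a ? a_shape[0] : a_shape[1];  // rows of op(A)
  const int64_t k = !trans_a ? a_shape[1] : a_shape[0];  // cols of op(A) == rows of op(B)
  const int64_t n = !trans_b ? b_shape[1] : b_shape[0];  // cols of op(B)
  assert(k == (!trans_b ? b_shape[0] : b_shape[1]) && "k is compatible");
  return {m, n, k};
}

int main() {
  // A: 4x8, B: 8x16, no transpose -> Y is 4x16 and k == 8.
  auto [m, n, k] = DeriveMNK({4, 8}, {8, 16}, /*trans_a=*/false, /*trans_b=*/false);
  std::printf("m=%lld n=%lld k=%lld\n",
              static_cast<long long>(m), static_cast<long long>(n), static_cast<long long>(k));
  return 0;
}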
