diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
index d7d98dfa15756..1e175b37b02d8 100644
--- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu
@@ -27,9 +27,11 @@ class GemmFloat8 final : public RocmKernel {
  private:
 #if !defined(DISABLE_FLOAT8_TYPES)
   template
-  Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const;
+  Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+                            const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const;
   template
-  Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const;
+  Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+                            const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const;
 
   template
   [[nodiscard]] inline auto* GetOp() const {
@@ -38,7 +40,7 @@ class GemmFloat8 final : public RocmKernel {
       return static_cast(tunable_op_.get());
     }
 
-    auto create = std::make_unique();  // avoid new
+    auto create = std::make_unique();  // avoid new
     tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) {
       auto release = std::unique_ptr();  // avoid delete
       release.reset(static_cast(ptr));
@@ -71,11 +73,15 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
   auto a_shape = A->Shape();
   auto b_shape = B->Shape();
 
-  ORT_ENFORCE(a_shape.NumDimensions() >= 2 && b_shape.NumDimensions() == 2);  // is in form of input @ weight
-  ORT_ENFORCE(a_shape[a_shape.NumDimensions() - 1] == b_shape[0]);            // k is compatiable
+  ORT_ENFORCE(a_shape.NumDimensions() == 2);
+  ORT_ENFORCE(b_shape.NumDimensions() == 2);
 
-  TensorShapeVector output_shape = a_shape.AsShapeVector();
-  output_shape[output_shape.size() - 1] = b_shape[b_shape.NumDimensions() - 1];
+  auto m = !transA_ ? a_shape[0] : a_shape[1];
+  auto k = !transA_ ? a_shape[1] : a_shape[0];
+  ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1]));  // k is compatiable
+  auto n = !transB_ ? b_shape[1] : b_shape[0];
+
+  TensorShapeVector output_shape = {m, n};
   Tensor* Y = ctx->Output(0, output_shape);
 
   ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose");
@@ -84,13 +90,13 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
   ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling");
 
   if (A->IsDataType()) {
-    return ComputeFp8Fp16Fp16(ctx, A, scale_a, B, Y);
+    return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y);
   } else if (A->IsDataType()) {
-    return ComputeFp8Fp16Fp16(ctx, A, scale_a, B, Y);
+    return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y);
   } else if (B->IsDataType()) {
-    return ComputeFp16Fp8Fp16(ctx, A, B, scale_b, Y);
+    return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y);
   } else if (B->IsDataType()) {
-    return ComputeFp16Fp8Fp16(ctx, A, B, scale_b, Y);
+    return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y);
   }
 
   return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8");
@@ -99,16 +105,11 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
 
 #if !defined(DISABLE_FLOAT8_TYPES)
 template
-Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const {
+Status GemmFloat8::ComputeFp8Fp16Fp16(
+    OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+    const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const {
   ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType());
 
-  auto a_shape = A->Shape();
-  auto b_shape = B->Shape();
-
-  auto m = a_shape.Slice(0, a_shape.NumDimensions() - 1).Size();
-  auto k = a_shape[a_shape.NumDimensions() - 1];
-  auto n = b_shape[b_shape.NumDimensions() - 1];
-
   onnxruntime::rocm::tunable::blas::GemmFloat8Params params{};
   params.tuning_ctx = GetTuningContext();
   params.stream = ctx->GetComputeStream();
@@ -148,16 +149,11 @@ Status GemmFloat8::ComputeFp8Fp16Fp16(OpKernelContext* ctx, const Tensor* A, con
 }
 
 template
-Status GemmFloat8::ComputeFp16Fp8Fp16(OpKernelContext* ctx, const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const {
+Status GemmFloat8::ComputeFp16Fp8Fp16(
+    OpKernelContext* ctx, int64_t m, int64_t n, int64_t k,
+    const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const {
   ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType());
 
-  auto a_shape = A->Shape();
-  auto b_shape = B->Shape();
-
-  auto m = a_shape.Slice(0, a_shape.NumDimensions() - 1).Size();
-  auto k = a_shape[a_shape.NumDimensions() - 1];
-  auto n = b_shape[b_shape.NumDimensions() - 1];
-
   onnxruntime::rocm::tunable::blas::GemmFloat8Params params{};
   params.tuning_ctx = GetTuningContext();
   params.stream = ctx->GetComputeStream();