From fe8a10caa40f64a8fbd144e7049cf5b14c65542d Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Fri, 20 Sep 2024 17:24:13 -0700
Subject: [PATCH] Address ZeroK case for Gemm for CPU and CUDA (#22111)

### Description
When K == 0, output an M x N matrix filled with the bias if present, or filled with zeros otherwise. This brings Gemm in line with MatMul behavior, especially when Gemm is used to fuse MatMul with Add.

### Motivation and Context
* Comply with the NumPy spec for MatMul.
* Address the case where empty initializers are used for computation.
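
For illustration, a minimal stand-alone sketch of the intended zero-K semantics (a hypothetical helper, not code from this change):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper: semantics of Gemm's
// Y = alpha * A(M x K) * B(K x N) + beta * C when K == 0. The A*B term is
// an empty sum, so Y is beta * C broadcast to (M, N) when a bias is given,
// and all zeros otherwise -- matching NumPy matmul on empty inputs plus Add.
std::vector<float> GemmZeroK(std::size_t M, std::size_t N, float beta,
                             const float* c /* optional (N,) bias, may be null */) {
  std::vector<float> y(M * N, 0.0f);  // zero-fill covers the no-bias case
  if (c != nullptr && beta != 0.0f) {
    for (std::size_t i = 0; i < M; ++i)
      for (std::size_t j = 0; j < N; ++j)
        y[i * N + j] = beta * c[j];  // broadcast the bias row across M rows
  }
  return y;
}
```

With M = N = 4, beta = 1 and a bias of ones this returns a 4x4 matrix of ones, which is what the new ZeroKWithBias test below checks.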
---
 onnxruntime/core/providers/cpu/math/gemm.cc   | 59 ++++++++++++-------
 .../core/providers/cpu/math/gemm_helper.h     |  3 +-
 onnxruntime/core/providers/cpu/math/matmul.cc | 10 ++--
 onnxruntime/core/providers/cuda/math/gemm.cc  | 10 ++++
 .../test/providers/cpu/math/gemm_test.cc      | 40 +++++++++++++
 5 files changed, 95 insertions(+), 27 deletions(-)
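
Commentary (ignored when the patch is applied): the hunks below zero the output through an Eigen row-major map rather than std::fill. A minimal sketch of that pattern, assuming an alias equivalent to ONNX Runtime's EigenMatrixMapRowMajor:

```cpp
#include <Eigen/Core>

// Assumed to mirror ONNX Runtime's EigenMatrixMapRowMajor alias;
// a hypothetical stand-alone copy for illustration only.
template <typename T>
using EigenMatrixMapRowMajor =
    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

// Zero an M x N output buffer in place, as the new K == 0 paths do.
template <typename T>
void ZeroOutput(T* y_data, Eigen::Index M, Eigen::Index N) {
  EigenMatrixMapRowMajor<T> dest(y_data, M, N);
  dest.setZero();  // expression-level fill; vectorizes for float/double
}
```

Mapping the existing buffer and calling setZero keeps Gemm's zero-K output path identical to MatMul's, which is the consistency goal stated in the description.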
diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc
index 5a886cce9d5d0..5406dd1a40446 100644
--- a/onnxruntime/core/providers/cpu/math/gemm.cc
+++ b/onnxruntime/core/providers/cpu/math/gemm.cc
@@ -154,6 +154,14 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
   // Broadcast the bias as needed if bias is given
   GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
 
+  if (K == 0) {
+    if (beta == 0 || c_data == nullptr) {
+      EigenMatrixMapRowMajor<T> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
+      dest.setZero();
+    }
+    return;
+  }
+
   math::Gemm<T>(trans_a, trans_b,
                 M, N, K,
                 alpha,
@@ -179,16 +187,18 @@ void Gemm<MLFloat16>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans
   if (M == 0 || N == 0)
     return;
 
-#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wclass-memaccess"
-#endif
-  // MLFloat16's constructor is explicit, so here we need to use memset
+  if (K == 0) {
+    if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) {
+      GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
+    } else {
+      auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
+      std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero);
+    }
+    return;
+  }
+
   if (c_data == nullptr)
-    memset(&beta, 0, sizeof(MLFloat16));
-#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
-#pragma GCC diagnostic pop
-#endif
+    beta = onnxruntime::MLFloat16::Zero;
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
   bool support_mlas = false;
   if (c_shape == nullptr) {
@@ -413,19 +423,24 @@ Status Gemm<float>::Compute(OpKernelContext* context) const {
                              c_data, c_shape, y_data, thread_pool);
   } else {
     GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data);
-    MlasGemm(
-        trans_A_,
-        static_cast<size_t>(M),
-        static_cast<size_t>(N),
-        static_cast<size_t>(K),
-        alpha_,
-        A->Data<float>(),
-        static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
-        packed_b_.get(),
-        c_data != nullptr ? beta_ : 0.0f,
-        y_data,
-        static_cast<size_t>(N),
-        thread_pool);
+    if (K > 0) {
+      MlasGemm(
+          trans_A_,
+          static_cast<size_t>(M),
+          static_cast<size_t>(N),
+          static_cast<size_t>(K),
+          alpha_,
+          A->Data<float>(),
+          static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
+          packed_b_.get(),
+          c_data != nullptr ? beta_ : 0.0f,
+          y_data,
+          static_cast<size_t>(N),
+          thread_pool);
+    } else if (beta_ == 0 || c_data == nullptr) {
+      EigenMatrixMapRowMajor<float> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
+      dest.setZero();
+    }
   }
 
   ComputeActivation(y_data, SafeInt<size_t>(M) * N, thread_pool);
diff --git a/onnxruntime/core/providers/cpu/math/gemm_helper.h b/onnxruntime/core/providers/cpu/math/gemm_helper.h
index f37b00ac2c16d..b55bf2b5dbbfa 100644
--- a/onnxruntime/core/providers/cpu/math/gemm_helper.h
+++ b/onnxruntime/core/providers/cpu/math/gemm_helper.h
@@ -56,7 +56,8 @@ class GemmHelper {
       status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                                "Gemm: Invalid bias shape for broadcast");
     // it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
-    ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
+    // it is also possible that K == 0
+    ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0);
   }
 
   ptrdiff_t M() const { return M_; }
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
index 6a71283f9dbd4..2c6d23e4de908 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
@@ -106,8 +106,9 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
   if (helper.K() == 0) {
     // When we have (M, 0, N) then the inputs are empty, but the output should
     // be filled out with zeros.
-    auto output_span = y->MutableDataAsSpan<T>();
-    std::fill(output_span.begin(), output_span.end(), T{});
+    EigenMatrixMapRowMajor<T> dest(y->MutableData<T>(),
+                                   narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
+    dest.setZero();
     return Status::OK();
   }
 
@@ -241,8 +242,9 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
   if (helper.K() == 0) {
     // When we have (M, 0, N) then the inputs are empty, but the output should
     // be filled out with zeros.
-    auto output_span = y->MutableDataAsSpan<float>();
-    std::fill(output_span.begin(), output_span.end(), float{});
+    EigenMatrixMapRowMajor<float> dest(y->MutableData<float>(),
+                                       narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
+    dest.setZero();
     return Status::OK();
   }
 
diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc
index 4e61e0c8c69c6..7fa5e74b54248 100644
--- a/onnxruntime/core/providers/cuda/math/gemm.cc
+++ b/onnxruntime/core/providers/cuda/math/gemm.cc
@@ -137,6 +137,16 @@ Status Gemm<T>::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const
     }
   }
 
+  if (K == 0) {
+    if (beta_ == 0 || B == nullptr) {
+      // When we have (M, 0, N) then the output should be filled out with zeros
+      // unless we have a bias
+      Fill<CudaT>(Stream(ctx), reinterpret_cast<CudaT*>(Y->MutableData<T>()), CudaT(0.f),
+                  Y->Shape().Size());
+    }
+    return Status::OK();
+  }
+
   CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
   CudaT beta = ToCudaType<T>::FromFloat(beta_);
   // Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc
index 7ec84d87b2a8b..625ff29d4ccf9 100644
--- a/onnxruntime/test/providers/cpu/math/gemm_test.cc
+++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -641,6 +641,46 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) {
       .Config(run_with_tunable_op)
       .RunWithConfig();
 }
+
+TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
+  OpTester test("Gemm", 13);
+
+  test.AddAttribute("transA", static_cast<int64_t>(0));
+  test.AddAttribute("transB", static_cast<int64_t>(0));
+  test.AddAttribute("alpha", 1.0f);
+  test.AddAttribute("beta", 1.0f);
+
+  test.AddInput<TypeParam>("A", {4, 0}, {});
+  test.AddInput<TypeParam>("B", {0, 4}, {});
+  test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
+  test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));
+
+  test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
+                         kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
+                         kOpenVINOExecutionProvider})
+      .Config(run_with_tunable_op)
+      .RunWithConfig();
+}
+
+TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
+  OpTester test("Gemm", 13);
+
+  test.AddAttribute("transA", static_cast<int64_t>(0));
+  test.AddAttribute("transB", static_cast<int64_t>(0));
+  test.AddAttribute("alpha", 1.0f);
+  test.AddAttribute("beta", .0f);
+
+  test.AddInput<TypeParam>("A", {4, 0}, {});
+  test.AddInput<TypeParam>("B", {0, 4}, {});
+  test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));
+
+  test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
+                         kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
+                         kOpenVINOExecutionProvider})
+      .Config(run_with_tunable_op)
+      .RunWithConfig();
+}
+
 TYPED_TEST(GemmOpTypedTests, MissingBias) {
   OpTester test("Gemm", 11);