Skip to content

Commit

Permalink
Address ZeroK case for Gemm for CPU and CUDA (#22111)
Browse files Browse the repository at this point in the history
### Description
When K == 0 output a MxN matrix filled with bias if present or filled
with zeros.
This brings it inline with MatMul behavior especially when Gemm is used
to fuse MatMul with Add.


### Motivation and Context
* Comply with numpy spec of MatMul
* Address a case when empty initializers are used for computation.
  • Loading branch information
yuslepukhin authored Sep 21, 2024
1 parent 8d2d407 commit fe8a10c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 27 deletions.
59 changes: 37 additions & 22 deletions onnxruntime/core/providers/cpu/math/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
// Broadcast the bias as needed if bias is given
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);

if (K == 0) {
if (beta == 0 || c_data == nullptr) {
EigenMatrixMapRowMajor<T> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
dest.setZero();
}
return;
}

math::Gemm<T>(trans_a, trans_b,
M, N, K,
alpha,
Expand All @@ -179,16 +187,18 @@ void Gemm<MLFloat16>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans
if (M == 0 || N == 0)
return;

#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
// MLFloat16's constructor is explicit, so here we need to use memset
if (K == 0) {
if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) {
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
} else {
auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero);
}
return;
}

if (c_data == nullptr)
memset(&beta, 0, sizeof(MLFloat16));
#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
#pragma GCC diagnostic pop
#endif
beta = onnxruntime::MLFloat16::Zero;
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
bool support_mlas = false;
if (c_shape == nullptr) {
Expand Down Expand Up @@ -413,19 +423,24 @@ Status Gemm<float>::Compute(OpKernelContext* context) const {
c_data, c_shape, y_data, thread_pool);
} else {
GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data);
MlasGemm(
trans_A_,
static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
alpha_,
A->Data<float>(),
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
packed_b_.get(),
c_data != nullptr ? beta_ : 0.0f,
y_data,
static_cast<size_t>(N),
thread_pool);
if (K > 0) {
MlasGemm(
trans_A_,
static_cast<size_t>(M),
static_cast<size_t>(N),
static_cast<size_t>(K),
alpha_,
A->Data<float>(),
static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
packed_b_.get(),
c_data != nullptr ? beta_ : 0.0f,
y_data,
static_cast<size_t>(N),
thread_pool);
} else if (beta_ == 0 || c_data == nullptr) {
EigenMatrixMapRowMajor<float> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
dest.setZero();
}
}

ComputeActivation(y_data, SafeInt<size_t>(M) * N, thread_pool);
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/math/gemm_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class GemmHelper {
status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast");

// it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
// it is also possible that K == 0
ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0);
}

ptrdiff_t M() const { return M_; }
Expand Down
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/cpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
if (helper.K() == 0) {
// When we have (M, 0, N) then the inputs are empty, but the output should
// be filled out with zeros.
auto output_span = y->MutableDataAsSpan<T>();
std::fill(output_span.begin(), output_span.end(), T{});
EigenMatrixMapRowMajor<T> dest(y->MutableData<T>(),
narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
dest.setZero();
return Status::OK();
}

Expand Down Expand Up @@ -241,8 +242,9 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
if (helper.K() == 0) {
// When we have (M, 0, N) then the inputs are empty, but the output should
// be filled out with zeros.
auto output_span = y->MutableDataAsSpan<float>();
std::fill(output_span.begin(), output_span.end(), float{});
EigenMatrixMapRowMajor<float> dest(y->MutableData<float>(),
narrow<Eigen::Index>(helper.M()), narrow<Eigen::Index>(helper.N()));
dest.setZero();
return Status::OK();
}

Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/math/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ Status Gemm<T>::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const
}
}

if (K == 0) {
if (beta_ == 0 || B == nullptr) {
// When we have (M, 0, N) then the output should be filled out with zeros
// unless we have a bias
Fill<CudaT>(Stream(ctx), reinterpret_cast<CudaT*>(Y->MutableData<T>()), CudaT(0.f),
Y->Shape().Size());
}
return Status::OK();
}

CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
CudaT beta = ToCudaType<T>::FromFloat(beta_);
// Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
Expand Down
40 changes: 40 additions & 0 deletions onnxruntime/test/providers/cpu/math/gemm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,46 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) {
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", .0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, MissingBias) {
OpTester test("Gemm", 11);

Expand Down

0 comments on commit fe8a10c

Please sign in to comment.