Skip to content

Commit

Permalink
Add more F16 kernels of XNNPack (microsoft#22381)
Browse files Browse the repository at this point in the history
### Description
1. Add Gemm, MatMul, Softmax, AveragePool and  Resize F16 kernels

This PR has included all changes in microsoft#22378


[AB#51066](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/51066)

[AB#51026](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/51026)

2. Matrix B must be const and martrix A and B dim_size shoule NOT bigger
than 2 in XNNPack, so I added 2 tests in matmul_test.cc to make sure
it's really tested. (that is, compute() must be called.)
### Motivation and Context
  • Loading branch information
mszhanyi authored and Ishwar Raut committed Nov 19, 2024
1 parent e6a0cf8 commit c37821e
Show file tree
Hide file tree
Showing 19 changed files with 334 additions and 203 deletions.
10 changes: 7 additions & 3 deletions onnxruntime/core/framework/kernel_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>

#include "core/framework/kernel_type_str_resolver.h"
Expand Down Expand Up @@ -310,9 +311,12 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
for (auto i = range.first; i != range.second; ++i) {
if (i->second.kernel_def &&
i->second.kernel_def->IsConflict(*create_info.kernel_def)) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Failed to add kernel for " + key +
": Conflicting with a registered kernel with op versions.");
int since_version = i->second.kernel_def->SinceVersion().first;
std::string since_version_str = std::to_string(since_version);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Failed to add kernel for ", key,
": Conflicting with a registered kernel with op versions. the since version is: ",
since_version_str);
}
}

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ const NodeUnit* ClipReluChecker(const NodeUnit& node_unit,
} // namespace

bool NodeSupportChecker::IsNodeSupported(const NodeUnit& nodeunit) {
#ifndef XNNPACK_FP16_SUPPORTED
// check whether the hardware support XNNPack FP16
// Note. In CI, ios pipeline on ADO doesn't support XNNPack FP16. Because ADO mac pool is still x64.
const auto& inputs = nodeunit.Inputs();
const auto& x_arg = inputs[0].node_arg;
const auto* x_type = x_arg.TypeAsProto();
if (x_type == nullptr || x_type->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
return false;
}
#endif

static std::unordered_map<std::string, CheckerFn> checkers{
{"Conv", Conv::IsOnnxNodeSupported},
{"ConvTranspose", ConvTranspose::IsOnnxNodeSupported},
Expand Down
115 changes: 78 additions & 37 deletions onnxruntime/core/providers/xnnpack/math/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "gemm.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"

namespace onnxruntime {
namespace xnnpack {
Expand Down Expand Up @@ -37,7 +38,8 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra
const auto* A_type = A_arg->TypeAsProto();

if (A_type == nullptr ||
A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
(A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
break;
}

Expand Down Expand Up @@ -74,19 +76,26 @@ bool Gemm::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& gra
supported = true;

} while (false);

return supported;
}

Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*enable_caches*/ true) {
const auto& node{Node()};

info.GetAttrOrDefault<float>("alpha", &alpha_, 1.f);
info.GetAttrOrDefault<float>("beta", &beta_, 1.f);

const auto& node{Node()};
const auto& input_defs = node.InputDefs();
const auto* shapeA = input_defs[0]->Shape();
const auto* shapeB = input_defs[1]->Shape();

const NodeArg& X = *input_defs[0];
auto input_dtype = X.TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
op_compute_type_ = OpComputeType::op_compute_type_fp32;
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
op_compute_type_ = OpComputeType::op_compute_type_fp16;
}

const NodeArg* C_arg = input_defs.size() == 2 ? nullptr : input_defs[2];

C_matrix_exists_ = C_arg && C_arg->Exists();
Expand Down Expand Up @@ -127,32 +136,49 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr,

// flags - 1 - for no transpose - 0 for transpose
uint32_t flags = trans_B_ == CblasTrans ? 0 : XNN_FLAG_TRANSPOSE_WEIGHTS;

float output_min = clip_min_max_ ? clip_min_max_->first : -INFINITY;
float output_max = clip_min_max_ ? clip_min_max_->second : INFINITY;

const float* bias_Data = nullptr;

if (C_matrix_exists_) {
bias_Data = tensor.Data<float>();
}

auto code_cache = GetCodeCache();
auto weights_cache = GetWeightsCache();
xnn_status status = xnn_status::xnn_status_uninitialized;
struct xnn_operator* p = nullptr;
status = xnn_create_fully_connected_nc_f32(
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride,
B_->Data<float>(), // const float* kernel,
bias_Data, // const float* bias,
output_min, output_max,
flags,
GetCodeCache(), GetWeightsCache(),
&p);
float foutput_min = clip_min_max_ ? clip_min_max_->first : -INFINITY;
float foutput_max = clip_min_max_ ? clip_min_max_->second : INFINITY;
if (op_compute_type_ == OpComputeType::op_compute_type_fp32) {
const float* bias_data = nullptr;
if (C_matrix_exists_) {
bias_data = tensor.Data<float>();
}
status = xnn_create_fully_connected_nc_f32(
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride,
B_->Data<float>(), // const float* kernel,
bias_data, // const float* bias,
foutput_min, foutput_max,
flags,
code_cache, weights_cache,
&p);
} else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) {
const MLFloat16* bias_data = nullptr;
if (C_matrix_exists_) {
bias_data = tensor.Data<MLFloat16>();
}
status = xnn_create_fully_connected_nc_f16(
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_channels,
trans_B_ == CblasNoTrans ? B_->Shape()[0] : B_->Shape()[1], // size_t input_stride,
trans_B_ == CblasNoTrans ? B_->Shape()[1] : B_->Shape()[0], // size_t output_stride,
B_->Data<MLFloat16>(), // const MLFloat16* kernel,
bias_data, // const float* bias,
foutput_min, foutput_max,
flags,
code_cache, weights_cache,
&p);
}

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_f32 returned ", status);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_create_fully_connected_nc_",
OpTypeToString(op_compute_type_), " returned ", status);
}
op0_.reset(p);

Expand All @@ -169,19 +195,30 @@ Status Gemm::Compute(OpKernelContext* context) const {
return Status::OK();
}

xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(),
// Number of rows to multiply
trans_A_ == CblasNoTrans ? M_ : K_,
threadpool);
auto reshape_func = xnn_reshape_fully_connected_nc_f32;
if (op_compute_type_ == OpComputeType::op_compute_type_fp16) {
reshape_func = xnn_reshape_fully_connected_nc_f16;
}
xnn_status status = reshape_func(op0_.get(),
// Number of rows to multiply
trans_A_ == CblasNoTrans ? M_ : K_,
threadpool);

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_",
OpTypeToString(op_compute_type_), " returned ", status);
}

status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data<float>(), Y->MutableData<float>());
status = xnn_status_invalid_state;
if (op_compute_type_ == op_compute_type_fp32) {
status = xnn_setup_fully_connected_nc_f32(op0_.get(), A->Data<float>(), Y->MutableData<float>());
} else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_fully_connected_nc_f16(op0_.get(), A->Data<MLFloat16>(), Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_",
OpTypeToString(op_compute_type_), " returned ", status);
}

status = xnn_run_operator(op0_.get(), nullptr);
Expand All @@ -193,19 +230,23 @@ Status Gemm::Compute(OpKernelContext* context) const {
}

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 7, 8, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Gemm);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 9, 10, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Gemm);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Gemm, kOnnxDomain, 11, 12, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Gemm);

ONNX_OPERATOR_KERNEL_EX(Gemm, kOnnxDomain, 13, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Gemm);

} // namespace xnnpack
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/xnnpack/math/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Gemm : protected GemmBase, public XnnpackKernel {

float alpha_;
float beta_;

OpComputeType op_compute_type_ = OpComputeType::op_compute_type_invalid;
};

} // namespace xnnpack
Expand Down
Loading

0 comments on commit c37821e

Please sign in to comment.