Enable FP16 Clip and Handle Bias in FP16 Depthwise Conv #21493

Merged 8 commits on Jul 30, 2024. Changes shown below are from 3 commits.
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1762,6 +1762,7 @@ MLASCALL
MlasConvDepthwise(
const MLAS_FP16* const* Input,
const MLAS_FP16* Filter,
const MLAS_FP16* Bias,
MLAS_FP16* Output,
size_t Channels,
size_t OutputCount,
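For reference, this is how the MlasConvDepthwise declaration reads with the new Bias parameter in place. Only the first few parameters are visible in the hunk above; the rest are copied from the dwconv.cpp definition later in this diff, so treat this as a reconstruction rather than the verbatim header text:

void
MLASCALL
MlasConvDepthwise(
    const MLAS_FP16* const* Input,
    const MLAS_FP16* Filter,
    const MLAS_FP16* Bias,  // may be nullptr when the Conv node has no bias input
    MLAS_FP16* Output,
    size_t Channels,
    size_t OutputCount,
    size_t KernelSize,
    MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
    );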
31 changes: 16 additions & 15 deletions onnxruntime/core/mlas/lib/dwconv.cpp
@@ -14,7 +14,6 @@ Module Name:

--*/


#include "fp16_common.h"

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -24,19 +23,20 @@ void
MlasConvDepthwiseKernel(
const _mlas_fp16_* const* Input,
const _mlas_fp16_* Filter,
const _mlas_fp16_* Bias,
_mlas_fp16_* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize,
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
)
)
{
while (OutputCount > 0) {
size_t ChannelOffset = 0;
size_t c = Channels;

while (c >= 8) {
MLAS_FLOAT16X8 Accumulator = MlasZeroFloat16x8();
MLAS_FLOAT16X8 Accumulator = Bias == nullptr ? MlasZeroFloat16x8() : MlasLoadFloat16x8(&Bias[ChannelOffset]);
size_t ChannelKernelOffset = ChannelOffset;

for (size_t k = 0; k < KernelSize; k++) {
@@ -54,7 +54,7 @@
}

if (c >= 4) {
MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4();
MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]);
size_t ChannelKernelOffset = ChannelOffset;

for (size_t k = 0; k < KernelSize; k++) {
@@ -72,7 +72,7 @@
}

if (c > 0) {
MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4();
MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]);
size_t ChannelKernelOffset = ChannelOffset;

for (size_t k = 0; k < KernelSize; k++) {
@@ -86,8 +86,7 @@
Output += c;
}
if (PostProc) {
PostProc->Process(reinterpret_cast<MLAS_FP16*>(Output - Channels), 0, 0, 1, Channels,
Channels);
PostProc->Process(reinterpret_cast<MLAS_FP16*>(Output - Channels), 0, 0, 1, Channels, Channels);
}
Input += KernelSize;
OutputCount -= 1;
@@ -101,16 +100,17 @@
MlasConvDepthwiseKernel(
const _mlas_fp16_* const* Input,
const _mlas_fp16_* Filter,
const _mlas_fp16_* Bias,
_mlas_fp16_* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize,
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
)
)
{
while (OutputCount > 0) {
for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) {
float Accumulator = 0.0f;
float Accumulator = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[ChannelOffset]);
size_t ChannelKernelOffset = ChannelOffset;

for (size_t k = 0; k < KernelSize; k++) {
@@ -120,35 +120,36 @@
*Output++ = MLAS_Float2Half(Accumulator);
}
if (PostProc) {
PostProc->Process(reinterpret_cast<MLAS_FP16*>(Output - Channels), 0, 0, 1, Channels,
Channels);
PostProc->Process(reinterpret_cast<MLAS_FP16*>(Output - Channels), 0, 0, 1, Channels, Channels);
}
Input += KernelSize;
OutputCount -= 1;
}
}

#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED

#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED

void
MLASCALL
MlasConvDepthwise(
const MLAS_FP16* const* Input,
const MLAS_FP16* Filter,
const MLAS_FP16* Bias,
MLAS_FP16* Output,
size_t Channels,
size_t OutputCount,
size_t KernelSize,
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
)
)
{
MlasConvDepthwiseKernel(
reinterpret_cast<const _mlas_fp16_* const*>(Input),
reinterpret_cast<const _mlas_fp16_*>(Filter),
reinterpret_cast<const _mlas_fp16_*>(Bias),
reinterpret_cast<_mlas_fp16_*>(Output),
Channels,
OutputCount,
KernelSize,
PostProc);
PostProc
);
}
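Both kernel variants above implement the same contract: the per-channel accumulator now starts at Bias[channel] when a bias pointer is supplied and at zero otherwise; everything else is unchanged. Below is a minimal scalar sketch of that contract, using plain float instead of _mlas_fp16_ and assuming the usual MLAS depthwise filter layout of [KernelSize][Channels] (the stride lines are collapsed in the hunk above, so that layout is an assumption here):

#include <cstddef>

// Reference semantics only, not the MLAS code: Input is an indirection buffer of
// KernelSize pointers per output position, each pointing at a row of Channels values.
// Bias may be nullptr, in which case accumulation starts from zero.
static void ConvDepthwiseRef(const float* const* Input, const float* Filter, const float* Bias,
                             float* Output, size_t Channels, size_t OutputCount, size_t KernelSize) {
    while (OutputCount > 0) {
        for (size_t c = 0; c < Channels; c++) {
            float acc = (Bias == nullptr) ? 0.0f : Bias[c];
            for (size_t k = 0; k < KernelSize; k++) {
                acc += Input[k][c] * Filter[k * Channels + c];
            }
            *Output++ = acc;
        }
        Input += KernelSize;
        OutputCount -= 1;
    }
}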
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
@@ -139,8 +139,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr

bool share_prepacked_weights = (prepacked_weights != nullptr);

const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1);
// Don't pack the filter buffer if the MlasConvDepthwise path is used.
if (!(group_input_channels == 1 && group_output_channels == 1)) {
if (!is_depthwise_conv) {
packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false);
if (packed_W_size_ != 0) {
size_t packed_W_data_size = SafeInt<size_t>(group_count) * packed_W_size_;
@@ -472,6 +473,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
MlasConvDepthwise(
worker_indirection_buffer,
reordered_W,
Bdata,
worker_output,
static_cast<size_t>(M),
static_cast<size_t>(output_count),
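Bdata above is the bias pointer already available in FusedConvFp16::Compute; the change simply forwards it so the depthwise path no longer ignores the bias. As a hypothetical sketch of how such a pointer is typically obtained from the optional Conv bias input (names and the input index are illustrative, not the exact ORT code):

// Hypothetical sketch -- the optional bias input yields nullptr when absent,
// which the updated MlasConvDepthwise interprets as "no bias".
const Tensor* B = context->Input<Tensor>(2);
const MLFloat16* Bdata = (B != nullptr) ? B->Data<MLFloat16>() : nullptr;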
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/math/clip.cc
@@ -23,7 +23,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
float);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0,
float, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t);
float, MLFloat16, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t);
} // namespace op_kernel_type_control

using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
18 changes: 18 additions & 0 deletions onnxruntime/test/providers/cpu/math/clip_test.cc
@@ -119,6 +119,24 @@ TEST(MathOpTest, Clip_Default_uint64) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(MathOpTest, Clip_MLFloat16) {
OpTester test("Clip", 12);

std::vector<int64_t> dims{3, 3};
test.AddInput<MLFloat16>("X", dims,
{MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f),
MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f),
MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)});
test.AddInput<MLFloat16>("min", {}, {MLFloat16(0.0f)});
test.AddInput<MLFloat16>("max", {}, {MLFloat16(6.0f)});
test.AddOutput<MLFloat16>("Y", dims,
{MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f),
MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f),
MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)});

test.Run();
}

TEST(MathOpTest, Clip_int32) {
OpTester test("Clip", 12);

198 changes: 197 additions & 1 deletion onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -714,6 +714,202 @@ TEST(ConvFp16Test, Conv2D_group) {
TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}

TEST(ConvFp16Test, Depthwise2D_Bias) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
2, // group
vector<int64_t>{1, 1}, // kernel_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
{} // excluded EPs
};

vector<MLFloat16> X = {
MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f),
MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f),
MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f),

MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f),
MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f),
MLFloat16(15.0f), MLFloat16(16.0f), MLFloat16(17.0f)};
vector<int64_t> X_shape = {1, 2, 3, 3};
vector<MLFloat16> W = {MLFloat16(1.0f), MLFloat16(2.0f)};
vector<int64_t> W_shape = {2, 1, 1, 1};
vector<MLFloat16> B = {MLFloat16(1.0f), MLFloat16(-1.0f)};
vector<int64_t> B_shape = {2};
vector<int64_t> Y_shape = {1, 2, 3, 3};
auto expected_vals = {
MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f),
MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f),

MLFloat16(17.0f), MLFloat16(19.0f), MLFloat16(21.0f),
MLFloat16(23.0f), MLFloat16(25.0f), MLFloat16(27.0f),
MLFloat16(29.0f), MLFloat16(31.0f), MLFloat16(33.0f)};

TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true);
}
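For this 1x1 depthwise case the expected values reduce to Y[c][i] = W[c] * X[c][i] + B[c] per channel: channel 0 is X + 1 and channel 1 is 2*X - 1. A short sketch that regenerates the list (plain float stands in for MLFloat16; every value involved is exactly representable in fp16):

#include <cstdio>

int main() {
    const float W[2] = {1.0f, 2.0f};
    const float B[2] = {1.0f, -1.0f};
    // X is NCHW {1, 2, 3, 3}: channel 0 holds 0..8, channel 1 holds 9..17.
    for (int c = 0; c < 2; c++) {
        for (int i = 0; i < 9; i++) {
            std::printf("%g ", W[c] * static_cast<float>(c * 9 + i) + B[c]);  // matches expected_vals
        }
        std::printf("\n");
    }
    return 0;
}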

TEST(ConvFp16Test, Depthwise2D_Bias_Complex) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
13, // group
vector<int64_t>{2, 2}, // kernel_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
{} // excluded EPs
};

vector<MLFloat16> X = {
// C = 0
MLFloat16(0.0f), MLFloat16(1.0f),
MLFloat16(2.0f), MLFloat16(3.0f),

// C = 1
MLFloat16(4.0f), MLFloat16(5.0f),
MLFloat16(6.0f), MLFloat16(7.0f),

// C = 2
MLFloat16(8.0f), MLFloat16(9.0f),
MLFloat16(10.0f), MLFloat16(11.0f),

// C = 3
MLFloat16(12.0f), MLFloat16(13.0f),
MLFloat16(14.0f), MLFloat16(15.0f),

// C = 4
MLFloat16(16.0f), MLFloat16(17.0f),
MLFloat16(18.0f), MLFloat16(19.0f),

// C = 5
MLFloat16(20.0f), MLFloat16(21.0f),
MLFloat16(22.0f), MLFloat16(23.0f),

// C = 6
MLFloat16(24.0f), MLFloat16(25.0f),
MLFloat16(26.0f), MLFloat16(27.0f),

// C = 7
MLFloat16(28.0f), MLFloat16(29.0f),
MLFloat16(30.0f), MLFloat16(31.0f),

// C = 8
MLFloat16(32.0f), MLFloat16(33.0f),
MLFloat16(34.0f), MLFloat16(35.0f),

// C = 9
MLFloat16(36.0f), MLFloat16(37.0f),
MLFloat16(38.0f), MLFloat16(39.0f),

// C = 10
MLFloat16(40.0f), MLFloat16(41.0f),
MLFloat16(42.0f), MLFloat16(43.0f),

// C = 11
MLFloat16(44.0f), MLFloat16(45.0f),
MLFloat16(46.0f), MLFloat16(47.0f),

// C = 12
MLFloat16(48.0f), MLFloat16(49.0f),
MLFloat16(50.0f), MLFloat16(51.0f),
};
vector<int64_t> X_shape = {1, 13, 2, 2};
vector<MLFloat16> W = {
// M = 0
MLFloat16(0.0f), MLFloat16(1.0f),
MLFloat16(2.0f), MLFloat16(3.0f),

// M = 1
MLFloat16(4.0f), MLFloat16(5.0f),
MLFloat16(6.0f), MLFloat16(7.0f),

// M = 2
MLFloat16(8.0f), MLFloat16(9.0f),
MLFloat16(10.0f), MLFloat16(11.0f),

// M = 3
MLFloat16(12.0f), MLFloat16(13.0f),
MLFloat16(14.0f), MLFloat16(15.0f),

// M = 4
MLFloat16(16.0f), MLFloat16(17.0f),
MLFloat16(18.0f), MLFloat16(19.0f),

// M = 5
MLFloat16(20.0f), MLFloat16(21.0f),
MLFloat16(22.0f), MLFloat16(23.0f),

// M = 6
MLFloat16(24.0f), MLFloat16(25.0f),
MLFloat16(26.0f), MLFloat16(27.0f),

// M = 7
MLFloat16(28.0f), MLFloat16(29.0f),
MLFloat16(30.0f), MLFloat16(31.0f),

// M = 8
MLFloat16(32.0f), MLFloat16(33.0f),
MLFloat16(34.0f), MLFloat16(35.0f),

// M = 9
MLFloat16(36.0f), MLFloat16(37.0f),
MLFloat16(38.0f), MLFloat16(39.0f),

// M = 10
MLFloat16(40.0f), MLFloat16(41.0f),
MLFloat16(42.0f), MLFloat16(43.0f),

// M = 11
MLFloat16(44.0f), MLFloat16(45.0f),
MLFloat16(46.0f), MLFloat16(47.0f),

// M = 12
MLFloat16(48.0f), MLFloat16(49.0f),
MLFloat16(50.0f), MLFloat16(51.0f),
};
vector<int64_t> W_shape = {13, 1, 2, 2};
vector<MLFloat16> B = {
MLFloat16(1.0f),
MLFloat16(2.0f),
MLFloat16(3.0f),
MLFloat16(4.0f),
MLFloat16(5.0f),
MLFloat16(6.0f),
MLFloat16(7.0f),
MLFloat16(8.0f),
MLFloat16(9.0f),
MLFloat16(10.0f),
MLFloat16(11.0f),
MLFloat16(12.0f),
MLFloat16(13.0f),
};
vector<int64_t> B_shape = {13};
vector<int64_t> Y_shape = {1, 13, 1, 1};
auto expected_vals = {
MLFloat16(15.0f), // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 1.0
MLFloat16(128.0f),
MLFloat16(369.0f),
MLFloat16(738.0f),
MLFloat16(1235.0f),
MLFloat16(1860.0f),
MLFloat16(2613.0f), // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 7.0
MLFloat16(3494.0f),
MLFloat16(4503.0f),
MLFloat16(5640.0f),
MLFloat16(6905.0f),
MLFloat16(8298.0f),
MLFloat16(9819.0f), // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 13.0
};

TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);

// NNAPI/CoreML EP requires weight to be an initializer
TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true);
}
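Each expected value above is the per-channel dot product of the 2x2 input window with the matching 2x2 filter plus that channel's bias: with X[c][k] = W[c][k] = 4c + k and B[c] = c + 1, the output for channel c is the sum of (4c + k)^2 over k = 0..3 plus c + 1 (e.g. channel 0: 0 + 1 + 4 + 9 + 1 = 15). A short sketch that regenerates the full list:

#include <cstdio>

int main() {
    for (int c = 0; c < 13; c++) {
        float acc = static_cast<float>(c + 1);  // B[c]
        for (int k = 0; k < 4; k++) {
            const float v = static_cast<float>(4 * c + k);  // X[c][k] == W[c][k]
            acc += v * v;
        }
        std::printf("MLFloat16(%.1ff),\n", acc);  // matches expected_vals above
    }
    return 0;
}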

TEST(ConvFp16Test, ConvDimWithZero) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
@@ -1074,4 +1270,4 @@ TEST(ConvFp16Test, SharedPrepackedWeights) {
} // namespace test
} // namespace onnxruntime

#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED