From 4064fcbd09900d69a4d2c4d8c9b1f41cbce621ba Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 13 Mar 2024 14:46:02 +0100 Subject: [PATCH 1/9] Fix broken MaxPool NHWC Ops and ensure NCHW / NHWC parity. --- .../core/providers/cuda/cuda_nhwc_kernels.cc | 7 ++ .../providers/cuda/nn/max_pool_with_index.cu | 116 +++++++++++++----- .../providers/cuda/nn/max_pool_with_index.h | 2 +- onnxruntime/core/providers/cuda/nn/pool.cc | 17 ++- .../test/providers/cpu/nn/pool_op_test.cc | 36 +++--- 5 files changed, 126 insertions(+), 52 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index 8fdcaacdb0f29..7afd2d430ec46 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -74,6 +74,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kM MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, @@ -165,6 +167,7 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo +template __global__ void MaxPoolWithIndexKernel( int64_t batch, int64_t channels, @@ -44,11 +45,29 @@ __global__ void MaxPoolWithIndexKernel( int id = blockIdx.x * blockDim.x + threadIdx.x; if (id >= output_size) return; + auto compute_offset = + [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t { + if constexpr (Layout == LAYOUT_NCHW) { + return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index; + } else if constexpr (Layout == LAYOUT_NHWC) { + return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index; + } else { + static_assert(0, "unsupported layout"); + } + }; + int d_index, w_index, h_index, c_index, n_index, id_tmp; - fdm_d.divmod(id, id_tmp, d_index); - fdm_w.divmod(id_tmp, id_tmp, w_index); - fdm_h.divmod(id_tmp, id_tmp, h_index); - fdm_c.divmod(id_tmp, n_index, c_index); + if constexpr (Layout == LAYOUT_NCHW) { + fdm_d.divmod(id, id_tmp, d_index); + fdm_w.divmod(id_tmp, id_tmp, w_index); + fdm_h.divmod(id_tmp, id_tmp, h_index); + fdm_c.divmod(id_tmp, n_index, c_index); + } else if constexpr (Layout == LAYOUT_NHWC) { + fdm_c.divmod(id, id_tmp, c_index); + fdm_d.divmod(id_tmp, id_tmp, d_index); + fdm_w.divmod(id_tmp, id_tmp, w_index); + fdm_h.divmod(id_tmp, n_index, h_index); + } int64_t d_start = d_index * stride_d - pad_d; int64_t w_start = w_index * stride_w - pad_w; @@ -64,29 +83,44 @@ __global__ void MaxPoolWithIndexKernel( int64_t d_index_max = -1; int64_t w_index_max = -1; int64_t 
h_index_max = -1; - int64_t offset = (n_index * channels + c_index) * height * width * depth; + int64_t offset = compute_offset(n_index, c_index, 0, 0, 0); const T* p_slice = p_input + offset; - T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1; + T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1; for (int64_t d = d_start; d < d_end; d += dilation_d) { for (int64_t w = w_start; w < w_end; w += dilation_w) { for (int64_t h = h_start; h < h_end; h += dilation_h) { - if (p_slice[h * width * depth + w * depth + d] > maxval) { + auto pool_offset = compute_offset(0, 0, h, w, d); + if (p_slice[pool_offset] > maxval) { h_index_max = h; w_index_max = w; d_index_max = d; - maxval = static_cast(p_slice[h * width * depth + w * depth + d]); + maxval = static_cast(p_slice[pool_offset]); } } } } - p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max]; + p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)]; + if (p_indices) { - p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max - : offset + h_index_max + w_index_max * height + d_index_max * width * height; + if constexpr (Layout == LAYOUT_NCHW) { + p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max + : offset + h_index_max + w_index_max * height + d_index_max * width * height; + } else if constexpr (Layout == LAYOUT_NHWC) { + // The tests currently have to be provided in NHWC layout so that tests do not fail. When converting between + // layouts, does it make sense to do an index conversion as well? + // Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations + // which currently assume that indices reference to Tensors in NHWC layout. + int64_t id_nchw = (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index; + int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth; + + p_indices[id_nchw] = (storage_order == 0) + ? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max + : offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height; + } } } -template +template void MaxPoolWithIndex( cudaStream_t stream, const TensorShape& input_shape, @@ -99,14 +133,30 @@ void MaxPoolWithIndex( const T* p_input, T* p_output, int64_t* p_indices) { - int64_t batchs = input_shape[0]; - int64_t channels = input_shape[1]; - int64_t height = input_shape[2]; - int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1; - int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1; - int64_t pooled_height = output_shape[2]; - int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1; - int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1; + int64_t batchs, channels, height, width, depth; + int64_t pooled_height, pooled_width, pooled_depth; + if constexpr (Layout == LAYOUT_NCHW) { + batchs = input_shape[0]; + channels = input_shape[1]; + height = input_shape[2]; + width = kernel_shape.size() > 1 ? input_shape[3] : 1; + depth = kernel_shape.size() > 2 ? input_shape[4] : 1; + + pooled_height = output_shape[2]; + pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1; + pooled_depth = kernel_shape.size() > 2 ? 
output_shape[4] : 1; + + } else if constexpr (Layout == LAYOUT_NHWC) { + batchs = input_shape[0]; + height = input_shape[1]; + width = kernel_shape.size() > 1 ? input_shape[2] : 1; + depth = kernel_shape.size() > 2 ? input_shape[3] : 1; + channels = input_shape[input_shape.NumDimensions() - 1]; + + pooled_height = output_shape[1]; + pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1; + pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1; + } int64_t kernel_h = kernel_shape[0]; int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1; int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1; @@ -130,7 +180,7 @@ void MaxPoolWithIndex( fast_divmod fdm_d(static_cast(pooled_depth)); int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock); - MaxPoolWithIndexKernel<<>>( + MaxPoolWithIndexKernel<<>>( batchs, channels, height, @@ -162,8 +212,8 @@ void MaxPoolWithIndex( p_indices); } -#define INSTANTIATEMAXPOOLWITHINDEX(T) \ - template void MaxPoolWithIndex( \ +#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout) \ + template void MaxPoolWithIndex( \ cudaStream_t stream, \ const TensorShape& input_shape, \ const TensorShape& output_shape, \ @@ -176,11 +226,19 @@ void MaxPoolWithIndex( T* p_output, \ int64_t* p_indices); -INSTANTIATEMAXPOOLWITHINDEX(float) -INSTANTIATEMAXPOOLWITHINDEX(double) -INSTANTIATEMAXPOOLWITHINDEX(half) -INSTANTIATEMAXPOOLWITHINDEX(int8_t) -INSTANTIATEMAXPOOLWITHINDEX(uint8_t) +INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW) +INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW) +INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW) +INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW) +INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW) + +#ifdef ENABLE_CUDA_NHWC_OPS +INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC) +INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC) +INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC) +INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC) +INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC) +#endif } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h index 27f5b241cc785..98f14c3f6a626 100644 --- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h +++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h @@ -7,7 +7,7 @@ namespace onnxruntime { namespace cuda { -template +template void MaxPoolWithIndex( cudaStream_t stream, const TensorShape& input_shape, diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 8bc96958693bc..05bd4e2a37941 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -87,6 +87,8 @@ POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInt POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true) POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true) POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true) +POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true) POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true) POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true) @@ -165,7 +167,7 @@ Status 
Pool::ComputeInternal(OpKernelContext* context) const pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto out_channel = NHWC ? x_shape[x_dims.size() - 1] : x_shape[1]; auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); @@ -255,7 +257,7 @@ Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) cons pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - auto out_channel = NHWC ? x_shape[3] : x_shape[1]; + auto out_channel = NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1]; auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); Tensor* Y = context->Output(0, TensorShape(y_dims)); @@ -265,10 +267,17 @@ Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) cons auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); - Tensor* I = context->Output(1, TensorShape(y_dims)); + + // I is in NCHW format and the contained indices use NCHW math to compute the index + auto i_dims = y_dims; + if (NHWC) { + std::swap(i_dims[1], i_dims[x_shape.NumDimensions() - 1]); + } + + Tensor* I = context->Output(1, TensorShape(i_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, + MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); } else { ORT_RETURN_IF_ERROR((Pool, NHWC>::ComputeInternal(context))); diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index e24cda17166ed..c905ff848ffdd 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -58,7 +58,7 @@ TEST(PoolTest, MaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); // TensorRT: result differs - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Only CUDA kernel has float 16 support @@ -117,7 +117,7 @@ TEST(PoolTest, MaxPool_F16) { test.AddInput("X", x_dims, f_X); test.AddOutput("Y", expected_dims, f_Y); // TensorRT: Assertion `!attrs.count("pads")' failed - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } #endif @@ -170,7 +170,7 @@ static void MaxPool_8_WithIndexTest(bool has_index, int64_t storage_order = 0) { : test.AddOutput("Indices", expected_dims, expected_indices_col); } test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kDnnlExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, + {kDnnlExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); } @@ -200,7 +200,7 @@ TEST(PoolTest, MaxPool1D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { @@ -222,7 +222,7 @@ static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_8_With_Index) { @@ -249,7 +249,7 @@ static void MaxPool1D_12_WithIndexTest_int8(int64_t storage_order) { test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { @@ -271,7 +271,7 @@ static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_12_With_Index_8bits) { @@ -309,9 +309,9 @@ TEST(PoolTest, MaxPool2D_uint8) { test.AddOutput("Output", output_shape, output); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {skOpenVINOExecutionProvider}); #else - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); #endif } @@ -337,7 +337,7 @@ TEST(PoolTest, MaxPool_10_Dilation_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations) { @@ -357,7 +357,7 @@ TEST(PoolTest, MaxPool_DefaultDilations) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_int8) { @@ -377,7 +377,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_uint8) { @@ -397,7 +397,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_uint8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_1d) { @@ -451,7 +451,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { @@ -479,7 +479,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_2d) { @@ -536,7 +536,7 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + { kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { @@ -565,7 +565,7 @@ TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { @@ -595,7 +595,7 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_3d) { From 205f227b7ff78464cc41f8e705f1bbc85d81113f Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 13 Mar 2024 15:50:50 +0100 Subject: [PATCH 2/9] Fix typo and rerun linter --- onnxruntime/core/providers/cuda/nn/pool.cc | 5 ++--- onnxruntime/test/providers/cpu/nn/pool_op_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 05bd4e2a37941..09ee8855eb8b3 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -267,18 +267,17 @@ Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) cons auto x_data = reinterpret_cast(X->Data()); auto y_data = reinterpret_cast(Y->MutableData()); - // I is in NCHW format and the contained indices use NCHW math to compute the index auto i_dims = y_dims; if (NHWC) { std::swap(i_dims[1], i_dims[x_shape.NumDimensions() - 1]); } - + Tensor* I = context->Output(1, TensorShape(i_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? 
nullptr : I->MutableData(); MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, - this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); + this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); } else { ORT_RETURN_IF_ERROR((Pool, NHWC>::ComputeInternal(context))); } diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index c905ff848ffdd..820eb31637840 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -309,7 +309,7 @@ TEST(PoolTest, MaxPool2D_uint8) { test.AddOutput("Output", output_shape, output); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {skOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); #else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); #endif @@ -536,7 +536,7 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - { kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { From fc4b4c8d2d1ff5ffc30d5fd7a24d76a988135c6a Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 13 Mar 2024 16:05:30 +0100 Subject: [PATCH 3/9] static_assert in else branch is not evaluated as constexpr, remove it --- onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu index 8c6c49eb6140a..3cc325c903aea 100644 --- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu +++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu @@ -51,8 +51,6 @@ __global__ void MaxPoolWithIndexKernel( return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index; } else if constexpr (Layout == LAYOUT_NHWC) { return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index; - } else { - static_assert(0, "unsupported layout"); } }; @@ -145,7 +143,6 @@ void MaxPoolWithIndex( pooled_height = output_shape[2]; pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1; pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1; - } else if constexpr (Layout == LAYOUT_NHWC) { batchs = input_shape[0]; height = input_shape[1]; From f21978a691fb5b5ef6ef39eacb8b56a43d48462c Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 14 Mar 2024 10:39:30 +0100 Subject: [PATCH 4/9] Fix NHWC<->NCHW conversion in CuDnnTensor::Set. Fix GlobalPool functionality, enable AveragePool NHWC tests and disable all pooling tests not supported by the CUDA EP. Add more MaxPool1D test cases. 
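For reference, the core of the CuDnnTensor::Set fix below: cuDNN's Nd tensor descriptors always take logical dimensions in NCHW(D) order, and the physical layout is expressed solely through the stride array. The following is a standalone host-side sketch of the dims/strides permutation the fixed code performs; the shape values and the pitches_of helper are illustrative, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major pitches: pitches[i] = product of dims[i+1 .. rank-1].
std::vector<int64_t> pitches_of(const std::vector<int64_t>& dims) {
  std::vector<int64_t> p(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
    p[i] = p[i + 1] * dims[i + 1];
  return p;
}

int main() {
  // Physical NHWC tensor with hypothetical shape {N=2, H=4, W=5, C=3}.
  std::vector<int64_t> nhwc_dims = {2, 4, 5, 3};
  auto pitch = pitches_of(nhwc_dims);  // {60, 15, 3, 1}

  // Descriptor handed to cuDNN: dims reordered to NCHW, strides still
  // describing the NHWC memory. N stays put, C takes the innermost pitch,
  // and H, W each shift one slot to the right, as the fixed Set() does.
  std::vector<int64_t> dims    = {nhwc_dims[0], nhwc_dims[3], nhwc_dims[1], nhwc_dims[2]};
  std::vector<int64_t> strides = {pitch[0],     pitch[3],     pitch[1],     pitch[2]};

  assert((dims    == std::vector<int64_t>{2, 3, 4, 5}));
  assert((strides == std::vector<int64_t>{60, 1, 15, 3}));  // channel stride == 1
  return 0;
}
```

The previous swap-based version exchanged only dims[1] and dims[rank-1], which mislabels the spatial axes: an NHWC {N, H, W, C} shape became {N, C, W, H} rather than {N, C, H, W}. The rotation loop keeps the spatial dims in order for every rank.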
--- .../core/providers/cuda/cudnn_common.cc | 29 +++++-- onnxruntime/core/providers/cuda/nn/pool.cc | 72 ++++++++++------- onnxruntime/core/providers/cuda/nn/pool.h | 6 +- .../test/providers/cpu/nn/pool_op_test.cc | 79 ++++++++++++++----- 4 files changed, 129 insertions(+), 57 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 39b73163794f0..1ec01ae31cbbf 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -37,13 +37,28 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat TensorPitches pitches(input_dims); InlinedVector dims(rank); InlinedVector strides(rank); - for (int i = 0; i < rank; i++) { - dims[i] = gsl::narrow_cast(input_dims[i]); - strides[i] = gsl::narrow_cast(pitches[i]); - } - if (is_nhwc) { - std::swap(dims[1], dims[rank - 1]); - std::swap(strides[1], strides[rank - 1]); + + if (!is_nhwc) { + for (int i = 0; i < rank; i++) { + dims[i] = gsl::narrow_cast(input_dims[i]); + strides[i] = gsl::narrow_cast(pitches[i]); + } + } else { + // NHWDC <-> NCHWD + + // N + dims[0] = gsl::narrow_cast(input_dims[0]); + strides[0] = gsl::narrow_cast(pitches[0]); + + // HWD + for (int i = 1; i < rank - 1; i++) { + dims[i + 1] = gsl::narrow_cast(input_dims[i]); + strides[i + 1] = gsl::narrow_cast(pitches[i]); + } + + // C + dims[1] = input_dims[rank - 1]; + strides[1] = pitches[rank - 1]; } CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 09ee8855eb8b3..3bf3e8fd789b2 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -147,8 +147,8 @@ class CudnnPoolingDescriptor final { cudnnPoolingDescriptor_t desc_; }; -template -Status Pool::ComputeInternal(OpKernelContext* context) const { +template +Status Pool::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -163,12 +163,16 @@ Status Pool::ComputeInternal(OpKernelContext* context) const auto strides = pool_attrs_.strides; if (pool_attrs_.global_pooling) { - kernel_shape.assign(x_dims.begin() + 2, x_dims.end()); + if constexpr (Layout == LAYOUT_NCHW) { + kernel_shape.assign(x_dims.begin() + 2, x_dims.end()); + } else if constexpr (Layout == LAYOUT_NHWC) { + kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1); + } pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - auto out_channel = NHWC ? x_shape[x_dims.size() - 1] : x_shape[1]; - auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); + auto out_channel = (Layout == LAYOUT_NHWC) ? x_shape[x_dims.size() - 1] : x_shape[1]; + auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); // special case when there is a dim value of 0 in the shape. 
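A note before the next hunk: as the existing comment says, cuDNN only takes 4D or 5D input, so 1D pooling has to be promoted to 2D, and where the dummy unit axis goes depends on the layout. Below is a sketch of the placement that hunk implements; the function and variable names are illustrative, and in the actual code the strides and pads receive an analogous insertion (refined again in patches 6 and 7).

```cpp
#include <cstdint>
#include <vector>

// Promote a 1D pooling problem to the 2D form cuDNN accepts by inserting a
// unit extent in front of the real spatial axis.
void promote_1d_to_2d(std::vector<int64_t>& x_dims,
                      std::vector<int64_t>& kernel_shape, bool nhwc) {
  if (nhwc) {
    x_dims.insert(x_dims.begin() + 1, 1);  // {N, W, C} -> {N, 1, W, C}
  } else {
    x_dims.insert(x_dims.begin() + 2, 1);  // {N, C, W} -> {N, C, 1, W}
  }
  kernel_shape.insert(kernel_shape.begin(), 1);  // {k} -> {1, k}
}
```

Since the inserted axis has extent 1, kernel 1 and stride 1, the 2D result is element-for-element identical to the original 1D computation.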
@@ -180,20 +184,20 @@ Status Pool::ComputeInternal(OpKernelContext* context) const TensorShapeVector x_dims_cudnn(x_dims.begin(), x_dims.end()); TensorShapeVector y_dims_cudnn(y_dims); if (kernel_shape.size() < 2) { - // cudnn only takes 4D or 5D input, so pad dimensions if needed - if (NHWC) { + // cuDNN only takes 4D or 5D input, so pad dimensions if needed + if (Layout == LAYOUT_NHWC) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1); - kernel_shape.insert(kernel_shape.begin() + 1, 1); - strides.insert(strides.begin() + 1, 1); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); } else { - x_dims_cudnn.push_back(1); - y_dims_cudnn.push_back(1); - kernel_shape.push_back(1); - strides.push_back(1); + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); } - pads.insert(pads.begin() + kernel_shape.size(), 0); - pads.insert(pads.end(), 0); } cudnnPoolingMode_t mode = CUDNN_POOLING_MAX; @@ -210,8 +214,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), Layout == LAYOUT_NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), Layout == LAYOUT_NHWC)); const auto input_count = x_shape.Size(); const auto output_count = y_shape.Size(); @@ -227,8 +231,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const const auto beta = Consts::Zero; CudnnTensor x_tensor; CudnnTensor y_tensor; - ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); - ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), NHWC)); + ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType(), Layout == LAYOUT_NHWC)); + ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType(), Layout == LAYOUT_NHWC)); CUDNN_RETURN_IF_ERROR( PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data)); @@ -237,8 +241,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const return Status::OK(); } -template -Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) const { +template +Status Pool, Layout>::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -253,12 +257,19 @@ Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) cons auto strides = this->pool_attrs_.strides; if (this->pool_attrs_.global_pooling) { - kernel_shape.assign(x_dims.begin() + 2, x_dims.end()); + // the logic below is most likely broken. Unfortunately no test runs through this case case. + // accessing x_dims.end() should result in a crash since it is OOB. + // i assume the last element is supposed to be accessed and thus used end() -1 / end() - 2. 
+ if constexpr (Layout == LAYOUT_NCHW) { + kernel_shape.assign(x_dims.begin() + 2, x_dims.end() - 1); + } else if constexpr (Layout == LAYOUT_NHWC) { + kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 2); + } pads.assign(kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } - auto out_channel = NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1]; - auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC); + auto out_channel = Layout == LAYOUT_NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1]; + auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC); Tensor* Y = context->Output(0, TensorShape(y_dims)); // special case when there is a dim value of 0 in the shape. @@ -269,17 +280,20 @@ Status Pool, NHWC>::ComputeInternal(OpKernelContext* context) cons // I is in NCHW format and the contained indices use NCHW math to compute the index auto i_dims = y_dims; - if (NHWC) { - std::swap(i_dims[1], i_dims[x_shape.NumDimensions() - 1]); + if constexpr (Layout == LAYOUT_NHWC) { + // y_dims in NHWDC format, i_dims has to be in NCHWD format. + i_dims.insert(i_dims.begin() + 1, i_dims.back()); // N*C*HWDC + i_dims.pop_back(); // NCHW } Tensor* I = context->Output(1, TensorShape(i_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads, - this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data); + MaxPoolWithIndex(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, + strides, pads, this->pool_attrs_.dilations, + this->pool_attrs_.storage_order, x_data, y_data, i_data); } else { - ORT_RETURN_IF_ERROR((Pool, NHWC>::ComputeInternal(context))); + ORT_RETURN_IF_ERROR((Pool, Layout == LAYOUT_NHWC>::ComputeInternal(context))); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h index 8b5152a1565a9..97f7c8b8762d5 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.h +++ b/onnxruntime/core/providers/cuda/nn/pool.h @@ -19,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class Pool, NHWC> final : public Pool, NHWC> { +template +class Pool, Layout> final : public Pool, Layout> { public: - explicit Pool(const OpKernelInfo& info) : Pool, NHWC>(info) {} + explicit Pool(const OpKernelInfo& info) : Pool, Layout>(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 820eb31637840..5e59abc0c9959 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -185,7 +185,7 @@ TEST(PoolTest, MaxPool_8_With_Index) { MaxPool_8_WithIndexTest(true, 1 /*storage_order*/); // col major } -TEST(PoolTest, MaxPool1D) { +TEST(PoolTest, MaxPool1D_case1) { OpTester test("MaxPool"); test.AddAttribute("auto_pad", ""); @@ -200,7 +200,46 @@ TEST(PoolTest, MaxPool1D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); +} + +TEST(PoolTest, 
MaxPool1D_case2) { + OpTester test("MaxPool"); + // no padding + test.AddAttribute("auto_pad", "VALID"); + test.AddAttribute("strides", std::vector{1}); + test.AddAttribute("pads", vector{0, 0}); + test.AddAttribute("kernel_shape", vector{2}); + + std::vector x_vals = {1, 2, 3, 4, 5}; + std::vector x_dims = {1, 1, 5}; + // The last dim is (5-2+1)/1 = 4 + std::vector expected_dims = {1, 1, 4}; + std::vector expected_vals = {2, 3, 4, 5}; + + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); +} + +TEST(PoolTest, MaxPool1D_case3) { + OpTester test("MaxPool"); + test.AddAttribute("auto_pad", ""); + test.AddAttribute("strides", std::vector{1}); + // Pad one element + test.AddAttribute("pads", vector{0, 1}); + test.AddAttribute("kernel_shape", vector{2}); + + std::vector x_vals = {1, 2, 3, 4, 5}; + std::vector x_dims = {1, 1, 5}; + // Since we padded it, the last dim is larger compared to the case above + std::vector expected_dims = {1, 1, 5}; + std::vector expected_vals = {2, 3, 4, 5, 5}; + + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { @@ -707,7 +746,7 @@ TEST(PoolTest, GlobalMaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } TEST(PoolTest, GlobalMaxPool3D) { @@ -783,7 +822,7 @@ TEST(PoolTest, GlobalMaxPool3D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool) { @@ -864,7 +903,7 @@ TEST(PoolTest, AveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_IncludePadPixel) { @@ -888,7 +927,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // test 'strides' attribute not specified @@ -907,7 +946,7 @@ TEST(PoolTest, AveragePool_DefaultStrides) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_10_ceil1_2d) { @@ -931,7 +970,7 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, AveragePool_19_dilation_2d) { @@ -955,7 +994,9 @@ TEST(PoolTest, AveragePool_19_dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, + kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, GlobalAveragePool) { @@ -1031,7 +1072,7 @@ TEST(PoolTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } TEST(PoolTest, GlobalAveragePool_Large_128) { @@ -1044,7 +1085,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } TEST(PoolTest, GlobalAveragePool_Large_256) { @@ -1057,7 +1098,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } TEST(PoolTest, LpPool) { @@ -1364,7 +1405,7 @@ TEST(PoolTest, LpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); } // test data generated with lp_pool_test_generator.py @@ -1396,7 +1437,8 @@ TEST(PoolTest, LpPool1d) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); y_count++; } } @@ -1428,7 +1470,7 @@ TEST(PoolTest, LpPool2d) { test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]); test.AddOutput("Y", y_sizes[y_count], ys[y_count]); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); y_count++; } } @@ -1446,7 +1488,8 @@ TEST(PoolTest, LpPoolCeilMode) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, GlobalLpPool) { @@ -1701,7 +1744,7 @@ TEST(PoolTest, GlobalLpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); } TEST(PoolTest, MaxPoolDimWithZeroForN) { @@ -1719,7 +1762,7 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}); + {kTensorrtExecutionProvider, kQnnExecutionProvider}); } } // namespace test From c87467695080443850a2cdc800ece4ff9cafd92a Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 14 Mar 2024 10:45:05 +0100 Subject: [PATCH 5/9] Fix rocm pipeline --- onnxruntime/core/providers/rocm/nn/pool.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/rocm/nn/pool.cc b/onnxruntime/core/providers/rocm/nn/pool.cc index 045c8b55c0b0d..3a82ab598004b 100644 --- a/onnxruntime/core/providers/rocm/nn/pool.cc +++ b/onnxruntime/core/providers/rocm/nn/pool.cc @@ -257,7 +257,7 @@ Status Pool>::ComputeInternal(OpKernelContext* context) const { Tensor* I = context->Output(1, TensorShape(y_dims)); if (nullptr != I || !this->pool_attrs_.default_dilations) { auto i_data = nullptr == I ? nullptr : I->MutableData(); - MaxPoolWithIndex( + MaxPoolWithIndex( this->Stream(context), x_shape, TensorShape(y_dims), From 6b8c722b328fcbd3066ebc400025697626d21315 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 15 Mar 2024 17:45:47 +0100 Subject: [PATCH 6/9] Add logic to check for unsupported asymmetric padding in pooling operations. 
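Background for this patch: ONNX stores the pads attribute as {x1_begin, x2_begin, ..., x1_end, x2_end, ...}, while cuDNN's pooling descriptor accepts a single padding value per spatial axis. The check added below therefore compares entry i against entry i + pads.size() / 2. A standalone sketch of the same predicate follows; the helper name is illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// ONNX pads layout: {x1_begin, ..., xn_begin, x1_end, ..., xn_end}.
bool pads_symmetric(const std::vector<int64_t>& pads) {
  const size_t half = pads.size() / 2;
  for (size_t i = 0; i < half; ++i)
    if (pads[i] != pads[half + i]) return false;
  return true;
}

int main() {
  assert(pads_symmetric({1, 2, 1, 2}));   // 2D: begin == end on both axes
  assert(!pads_symmetric({0, 0, 0, 1}));  // 2D: one extra pad at the end of W
  return 0;
}
```

Patch 7 walks this back for MaxPool on the grounds given in its message, namely that cuDNN already applies the extra padding at the end, which matches the intended behavior.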
--- .../core/providers/cuda/cudnn_common.cc | 4 +- .../providers/cuda/nn/max_pool_with_index.cu | 5 +- onnxruntime/core/providers/cuda/nn/pool.cc | 52 ++++++++++++------- .../test/providers/cpu/nn/pool_op_test.cc | 7 ++- 4 files changed, 40 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 1ec01ae31cbbf..9aa011c1d0ec4 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -57,8 +57,8 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat } // C - dims[1] = input_dims[rank - 1]; - strides[1] = pitches[rank - 1]; + dims[1] = gsl::narrow_cast(input_dims[rank - 1]); + strides[1] = gsl::narrow_cast(pitches[rank - 1]); } CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu index 3cc325c903aea..9311f044f4ec5 100644 --- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu +++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu @@ -45,7 +45,7 @@ __global__ void MaxPoolWithIndexKernel( int id = blockIdx.x * blockDim.x + threadIdx.x; if (id >= output_size) return; - auto compute_offset = + auto compute_offset = [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t { if constexpr (Layout == LAYOUT_NCHW) { return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index; @@ -108,7 +108,8 @@ __global__ void MaxPoolWithIndexKernel( // layouts, does it make sense to do an index conversion as well? // Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations // which currently assume that indices reference to Tensors in NHWC layout. - int64_t id_nchw = (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index; + int64_t id_nchw = + (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index; int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth; p_indices[id_nchw] = (storage_order == 0) diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 3bf3e8fd789b2..3e8060bb083f4 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -158,9 +158,19 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3."); } + // cuDNN does not support asymmetrical padding, check for symmetry. 
+ for (size_t idx = 0; idx < pool_attrs_.pads.size() / 2; ++idx) { + if (pool_attrs_.pads[idx] != pool_attrs_.pads[pool_attrs_.pads.size() / 2 + idx]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "pads not symmetric, unsupported"); + } + } + auto kernel_shape = pool_attrs_.kernel_shape; - auto pads = pool_attrs_.pads; auto strides = pool_attrs_.strides; + TensorShapeVector pads = pool_attrs_.pads; + + // cuDNN supports only symmetric padding, cut of all x{i}_end items + pads.resize(pads.size() / 2); if (pool_attrs_.global_pooling) { if constexpr (Layout == LAYOUT_NCHW) { @@ -172,7 +182,12 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons strides.assign(kernel_shape.size(), 1); } auto out_channel = (Layout == LAYOUT_NHWC) ? x_shape[x_dims.size() - 1] : x_shape[1]; - auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC); + + // shape inference done in SetOutputSize requires begin + end for padding, duplicate pads vector + TensorShapeVector asymmetrical_pads = pads; + std::copy(pads.begin(), pads.end(), std::back_insert_iterator(asymmetrical_pads)); + + auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &asymmetrical_pads, Layout == LAYOUT_NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); // special case when there is a dim value of 0 in the shape. @@ -185,18 +200,18 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons TensorShapeVector y_dims_cudnn(y_dims); if (kernel_shape.size() < 2) { // cuDNN only takes 4D or 5D input, so pad dimensions if needed - if (Layout == LAYOUT_NHWC) { - x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1); - y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1); - pads.insert(pads.begin(), 0); - kernel_shape.insert(kernel_shape.begin(), 1); - strides.insert(strides.begin(), 1); - } else { - x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); - y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - pads.insert(pads.begin(), 0); - kernel_shape.insert(kernel_shape.begin(), 1); - strides.insert(strides.begin(), 1); + if constexpr (Layout == LAYOUT_NHWC) { + x_dims_cudnn.insert(x_dims_cudnn.end() - 1, 1); + y_dims_cudnn.insert(y_dims_cudnn.end() - 1, 1); + pads.insert(pads.end(), 0); + kernel_shape.insert(kernel_shape.end(), 1); + strides.insert(strides.end(), 1); + } else { // Layout == LAYOUT_NCHW + x_dims_cudnn.insert(x_dims_cudnn.end(), 1); + y_dims_cudnn.insert(y_dims_cudnn.end(), 1); + pads.insert(pads.end(), 0); + kernel_shape.insert(kernel_shape.end(), 1); + strides.insert(strides.end(), 1); } } @@ -257,15 +272,12 @@ Status Pool, Layout>::ComputeInternal(OpKernelContext* context) co auto strides = this->pool_attrs_.strides; if (this->pool_attrs_.global_pooling) { - // the logic below is most likely broken. Unfortunately no test runs through this case case. - // accessing x_dims.end() should result in a crash since it is OOB. - // i assume the last element is supposed to be accessed and thus used end() -1 / end() - 2. if constexpr (Layout == LAYOUT_NCHW) { - kernel_shape.assign(x_dims.begin() + 2, x_dims.end() - 1); + kernel_shape.assign(x_dims.begin() + 2, x_dims.end()); } else if constexpr (Layout == LAYOUT_NHWC) { - kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 2); + kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1); } - pads.assign(kernel_shape.size(), 0); + pads.assign(2 * kernel_shape.size(), 0); // x{i}_begin + x{i}_end strides.assign(kernel_shape.size(), 1); } auto out_channel = Layout == LAYOUT_NHWC ? 
x_shape[x_shape.NumDimensions() - 1] : x_shape[1]; diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 5e59abc0c9959..17cc2e8285ad2 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -200,7 +200,7 @@ TEST(PoolTest, MaxPool1D_case1) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool1D_case2) { @@ -219,7 +219,7 @@ TEST(PoolTest, MaxPool1D_case2) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool1D_case3) { @@ -238,8 +238,7 @@ TEST(PoolTest, MaxPool1D_case3) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { From 3c6657cbd24500e03aba6e64ab62c2587335103b Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 18 Mar 2024 09:01:25 +0100 Subject: [PATCH 7/9] Enable MaxPool with asymmetric padding again for MaxPool. cuDNN does the correct thing padding to the end. --- onnxruntime/core/providers/cuda/nn/pool.cc | 20 ++++--------------- .../test/providers/cpu/nn/pool_op_test.cc | 2 +- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 3e8060bb083f4..4a25cfb45c8be 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -158,36 +158,22 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input dimension cannot be less than 3."); } - // cuDNN does not support asymmetrical padding, check for symmetry. - for (size_t idx = 0; idx < pool_attrs_.pads.size() / 2; ++idx) { - if (pool_attrs_.pads[idx] != pool_attrs_.pads[pool_attrs_.pads.size() / 2 + idx]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "pads not symmetric, unsupported"); - } - } - auto kernel_shape = pool_attrs_.kernel_shape; auto strides = pool_attrs_.strides; TensorShapeVector pads = pool_attrs_.pads; - // cuDNN supports only symmetric padding, cut of all x{i}_end items - pads.resize(pads.size() / 2); - if (pool_attrs_.global_pooling) { if constexpr (Layout == LAYOUT_NCHW) { kernel_shape.assign(x_dims.begin() + 2, x_dims.end()); } else if constexpr (Layout == LAYOUT_NHWC) { kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1); } - pads.assign(kernel_shape.size(), 0); + pads.assign(2*kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } auto out_channel = (Layout == LAYOUT_NHWC) ? 
x_shape[x_dims.size() - 1] : x_shape[1]; - // shape inference done in SetOutputSize requires begin + end for padding, duplicate pads vector - TensorShapeVector asymmetrical_pads = pads; - std::copy(pads.begin(), pads.end(), std::back_insert_iterator(asymmetrical_pads)); - - auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &asymmetrical_pads, Layout == LAYOUT_NHWC); + auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC); TensorShape y_shape(y_dims); Tensor* Y = context->Output(0, y_shape); // special case when there is a dim value of 0 in the shape. @@ -203,12 +189,14 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons if constexpr (Layout == LAYOUT_NHWC) { x_dims_cudnn.insert(x_dims_cudnn.end() - 1, 1); y_dims_cudnn.insert(y_dims_cudnn.end() - 1, 1); + pads.insert(pads.begin() + pads.size() / 2, 0); pads.insert(pads.end(), 0); kernel_shape.insert(kernel_shape.end(), 1); strides.insert(strides.end(), 1); } else { // Layout == LAYOUT_NCHW x_dims_cudnn.insert(x_dims_cudnn.end(), 1); y_dims_cudnn.insert(y_dims_cudnn.end(), 1); + pads.insert(pads.begin() + pads.size() / 2, 0); pads.insert(pads.end(), 0); kernel_shape.insert(kernel_shape.end(), 1); strides.insert(strides.end(), 1); diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 17cc2e8285ad2..a5050f5d3edbd 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -238,7 +238,7 @@ TEST(PoolTest, MaxPool1D_case3) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { From 071262319d30afeb10bbfbb8e5d455b56a0f1cd4 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 19 Mar 2024 12:21:51 +0100 Subject: [PATCH 8/9] fix lintrunner issues --- onnxruntime/core/providers/cuda/nn/pool.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 4a25cfb45c8be..4acdcfcf35491 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -168,7 +168,7 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons } else if constexpr (Layout == LAYOUT_NHWC) { kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1); } - pads.assign(2*kernel_shape.size(), 0); + pads.assign(2 * kernel_shape.size(), 0); strides.assign(kernel_shape.size(), 1); } auto out_channel = (Layout == LAYOUT_NHWC) ? 
x_shape[x_dims.size() - 1] : x_shape[1];

From 625b345ea22bd8f764279ceb5d2c43659102b333 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Tue, 19 Mar 2024 23:45:23 +0000
Subject: [PATCH 9/9] lintrunner

---
 onnxruntime/test/providers/cpu/nn/pool_op_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index dcb7d9c50385c..c8cf183291518 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -927,7 +927,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.SetOutputTolerance(0.0001f);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }

 // test 'strides' attribute not specified
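Closing note on the parity claim in the series title: after these patches the NCHW and NHWC paths share one kernel body, and the only layout-specific piece is the flattening order used for addressing. The sketch below mirrors patch 1's compute_offset lambda and checks that invariant; the shape values are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// NCHW keeps the spatial dims innermost; NHWC keeps the channel innermost.
int64_t offset_nchw(int64_t C, int64_t H, int64_t W, int64_t D,
                    int64_t n, int64_t c, int64_t h, int64_t w, int64_t d) {
  return (((n * C + c) * H + h) * W + w) * D + d;
}

int64_t offset_nhwc(int64_t C, int64_t H, int64_t W, int64_t D,
                    int64_t n, int64_t c, int64_t h, int64_t w, int64_t d) {
  return (((n * H + h) * W + w) * D + d) * C + c;
}

int main() {
  // N=2, C=3, H=4, W=5, D=1 (2D pooling, so depth is a dummy axis).
  const int64_t N = 2, C = 3, H = 4, W = 5, D = 1;
  std::vector<bool> hit_nchw(N * C * H * W * D), hit_nhwc(hit_nchw.size());
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w) {
          // Both flattenings visit each linear address exactly once: they
          // are permutations of the same index space, so a kernel written
          // against compute_offset addresses either layout correctly.
          hit_nchw.at(offset_nchw(C, H, W, D, n, c, h, w, 0)) = true;
          hit_nhwc.at(offset_nhwc(C, H, W, D, n, c, h, w, 0)) = true;
        }
  for (size_t i = 0; i < hit_nchw.size(); ++i) assert(hit_nchw[i] && hit_nhwc[i]);
  return 0;
}
```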