diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
index f416caecd115f..139cb231c07f6 100644
--- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
+++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -70,6 +70,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kM
                                                       MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
@@ -135,6 +137,7 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool)>,
diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc
--- a/onnxruntime/core/providers/cuda/cudnn_common.cc
+++ b/onnxruntime/core/providers/cuda/cudnn_common.cc
@@ ... @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dat
   TensorPitches pitches(input_dims);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);
-  for (int i = 0; i < rank; i++) {
-    dims[i] = gsl::narrow_cast<int>(input_dims[i]);
-    strides[i] = gsl::narrow_cast<int>(pitches[i]);
-  }
-  if (is_nhwc) {
-    std::swap(dims[1], dims[rank - 1]);
-    std::swap(strides[1], strides[rank - 1]);
+
+  if (!is_nhwc) {
+    for (int i = 0; i < rank; i++) {
+      dims[i] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i] = gsl::narrow_cast<int>(pitches[i]);
+    }
+  } else {
+    // NHWDC <-> NCHWD
+
+    // N
+    dims[0] = gsl::narrow_cast<int>(input_dims[0]);
+    strides[0] = gsl::narrow_cast<int>(pitches[0]);
+
+    // HWD
+    for (int i = 1; i < rank - 1; i++) {
+      dims[i + 1] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i + 1] = gsl::narrow_cast<int>(pitches[i]);
+    }
+
+    // C
+    dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
+    strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
   }
   CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
   return Status::OK();
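The old code swapped only C with the last spatial axis, which hands cuDNN a descriptor with H and W transposed (and scrambles 3-D shapes); the rewrite shifts every spatial axis right by one so each logical NCHW(D) axis keeps its own extent and stride. A standalone sketch of that permutation for a densely packed tensor (`MakeNchwViewOfNhwc` is a hypothetical helper, not ORT code):

```cpp
#include <cstdint>
#include <vector>

// Given a physical NHWC shape, produce the (dims, strides) pair in the logical
// NCHW order that cudnnSetTensorNdDescriptor expects.
void MakeNchwViewOfNhwc(const std::vector<int64_t>& nhwc_dims,
                        std::vector<int>& dims, std::vector<int>& strides) {
  const int rank = static_cast<int>(nhwc_dims.size());
  // Row-major pitches of the physical NHWC layout.
  std::vector<int64_t> pitches(rank, 1);
  for (int i = rank - 2; i >= 0; --i) pitches[i] = pitches[i + 1] * nhwc_dims[i + 1];

  dims.resize(rank);
  strides.resize(rank);
  dims[0] = static_cast<int>(nhwc_dims[0]);         // N stays in place
  strides[0] = static_cast<int>(pitches[0]);
  dims[1] = static_cast<int>(nhwc_dims[rank - 1]);  // C moves from last to second
  strides[1] = static_cast<int>(pitches[rank - 1]);
  for (int i = 1; i < rank - 1; ++i) {              // spatial axes shift right by one
    dims[i + 1] = static_cast<int>(nhwc_dims[i]);
    strides[i + 1] = static_cast<int>(pitches[i]);
  }
}
// e.g. NHWC {2, 4, 5, 3} -> dims {2, 3, 4, 5}, strides {60, 1, 15, 3}
```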
diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
index ef1155af127d1..9311f044f4ec5 100644
--- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
+++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
@@ -7,10 +7,11 @@
 
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/fast_divmod.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 __global__ void MaxPoolWithIndexKernel(
     int64_t batch,
     int64_t channels,
@@ -44,11 +45,27 @@ __global__ void MaxPoolWithIndexKernel(
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   if (id >= output_size) return;
 
+  auto compute_offset =
+      [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t {
+    if constexpr (Layout == LAYOUT_NCHW) {
+      return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index;
+    }
+  };
+
   int d_index, w_index, h_index, c_index, n_index, id_tmp;
-  fdm_d.divmod(id, id_tmp, d_index);
-  fdm_w.divmod(id_tmp, id_tmp, w_index);
-  fdm_h.divmod(id_tmp, id_tmp, h_index);
-  fdm_c.divmod(id_tmp, n_index, c_index);
+  if constexpr (Layout == LAYOUT_NCHW) {
+    fdm_d.divmod(id, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, id_tmp, h_index);
+    fdm_c.divmod(id_tmp, n_index, c_index);
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    fdm_c.divmod(id, id_tmp, c_index);
+    fdm_d.divmod(id_tmp, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, n_index, h_index);
+  }
 
   int64_t d_start = d_index * stride_d - pad_d;
   int64_t w_start = w_index * stride_w - pad_w;
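The `compute_offset` lambda gives the flattened element offset for either layout, and the divmod cascade inverts the same factorization to recover `(n, c, h, w, d)` from the output id: NCHW peels `d` first (innermost), NHWC peels `c` first. A standalone round-trip check of both directions (plain C++, not the kernel; the `LAYOUT_*` values are assumed to match cuda_utils.h):

```cpp
#include <cassert>
#include <cstdint>

constexpr bool LAYOUT_NCHW = false;  // assumed values
constexpr bool LAYOUT_NHWC = true;

// Mirrors compute_offset in MaxPoolWithIndexKernel.
template <bool Layout>
int64_t Offset(int64_t channels, int64_t height, int64_t width, int64_t depth,
               int64_t n, int64_t c, int64_t h, int64_t w, int64_t d) {
  if constexpr (Layout == LAYOUT_NCHW)
    return (((n * channels + c) * height + h) * width + w) * depth + d;
  else
    return (((n * height + h) * width + w) * depth + d) * channels + c;
}

int main() {
  const int64_t C = 3, H = 4, W = 5, D = 6;
  const int64_t id = Offset<LAYOUT_NHWC>(C, H, W, D, 1, 2, 3, 4, 5);
  // NHWC decomposition peels C first, then D, W, H, leaving N -- the same
  // order the kernel applies fdm_c, fdm_d, fdm_w, fdm_h.
  int64_t c = id % C, t = id / C;
  int64_t d = t % D;  t /= D;
  int64_t w = t % W;  t /= W;
  int64_t h = t % H;  int64_t n = t / H;
  assert(n == 1 && c == 2 && h == 3 && w == 4 && d == 5);
  return 0;
}
```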
@@ -64,29 +81,45 @@ __global__ void MaxPoolWithIndexKernel(
   int64_t d_index_max = -1;
   int64_t w_index_max = -1;
   int64_t h_index_max = -1;
-  int64_t offset = (n_index * channels + c_index) * height * width * depth;
+  int64_t offset = compute_offset(n_index, c_index, 0, 0, 0);
   const T* p_slice = p_input + offset;
-  T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1;
+  T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1;
   for (int64_t d = d_start; d < d_end; d += dilation_d) {
     for (int64_t w = w_start; w < w_end; w += dilation_w) {
       for (int64_t h = h_start; h < h_end; h += dilation_h) {
-        if (p_slice[h * width * depth + w * depth + d] > maxval) {
+        auto pool_offset = compute_offset(0, 0, h, w, d);
+        if (p_slice[pool_offset] > maxval) {
           h_index_max = h;
           w_index_max = w;
           d_index_max = d;
-          maxval = static_cast<T>(p_slice[h * width * depth + w * depth + d]);
+          maxval = static_cast<T>(p_slice[pool_offset]);
         }
       }
     }
   }
-  p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max];
+  p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)];
+
   if (p_indices) {
-    p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
-                                       : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    if constexpr (Layout == LAYOUT_NCHW) {
+      p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
+                                         : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      // The tests currently have to be provided in NHWC layout so that tests do not fail. When converting between
+      // layouts, does it make sense to do an index conversion as well?
+      // Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations
+      // which currently assume that indices reference Tensors in NHWC layout.
+      int64_t id_nchw =
+          (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index;
+      int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth;
+
+      p_indices[id_nchw] = (storage_order == 0)
+                               ? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max
+                               : offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height;
+    }
   }
 }
 
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
@@ -99,14 +132,29 @@ void MaxPoolWithIndex(
     const T* p_input,
     T* p_output,
     int64_t* p_indices) {
-  int64_t batchs = input_shape[0];
-  int64_t channels = input_shape[1];
-  int64_t height = input_shape[2];
-  int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1;
-  int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
-  int64_t pooled_height = output_shape[2];
-  int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
-  int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  int64_t batchs, channels, height, width, depth;
+  int64_t pooled_height, pooled_width, pooled_depth;
+  if constexpr (Layout == LAYOUT_NCHW) {
+    batchs = input_shape[0];
+    channels = input_shape[1];
+    height = input_shape[2];
+    width = kernel_shape.size() > 1 ? input_shape[3] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
+
+    pooled_height = output_shape[2];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    batchs = input_shape[0];
+    height = input_shape[1];
+    width = kernel_shape.size() > 1 ? input_shape[2] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[3] : 1;
+    channels = input_shape[input_shape.NumDimensions() - 1];
+
+    pooled_height = output_shape[1];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1;
+  }
   int64_t kernel_h = kernel_shape[0];
   int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1;
   int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1;
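On the host side the only layout difference is where the channel count sits in the shape; missing spatial dims default to 1 so the 3-D kernel also covers 1-D and 2-D pooling. A sketch of the unpacking rule (`UnpackPoolDims` is a hypothetical helper, not ORT code):

```cpp
#include <cstdint>
#include <vector>

struct PoolDims {
  int64_t n, c, h, w, d;
};

// Spatial rank comes from kernel_shape.size(); absent spatial dims are 1.
template <bool IsNhwc>
PoolDims UnpackPoolDims(const std::vector<int64_t>& shape, size_t spatial_rank) {
  PoolDims p{shape[0], 0, 0, 1, 1};
  if constexpr (IsNhwc) {
    p.c = shape[shape.size() - 1];  // channels are last
    p.h = shape[1];
    if (spatial_rank > 1) p.w = shape[2];
    if (spatial_rank > 2) p.d = shape[3];
  } else {
    p.c = shape[1];                 // channels follow the batch dim
    p.h = shape[2];
    if (spatial_rank > 1) p.w = shape[3];
    if (spatial_rank > 2) p.d = shape[4];
  }
  return p;
}
// UnpackPoolDims<true>({1, 32, 32, 8}, 2) -> {n=1, c=8, h=32, w=32, d=1}
```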
@@ -130,7 +178,7 @@ void MaxPoolWithIndex(
   fast_divmod fdm_d(static_cast<int>(pooled_depth));
 
   int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
-  MaxPoolWithIndexKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+  MaxPoolWithIndexKernel<T, Layout><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
       batchs,
       channels,
       height,
@@ -162,8 +210,8 @@ void MaxPoolWithIndex(
       p_indices);
 }
 
-#define INSTANTIATEMAXPOOLWITHINDEX(T)     \
-  template void MaxPoolWithIndex<T>(       \
+#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout) \
+  template void MaxPoolWithIndex<T, Layout>(   \
       cudaStream_t stream,                     \
       const TensorShape& input_shape,          \
       const TensorShape& output_shape,         \
@@ -176,11 +224,19 @@ void MaxPoolWithIndex(
       T* p_output,                             \
       int64_t* p_indices);
 
-INSTANTIATEMAXPOOLWITHINDEX(float)
-INSTANTIATEMAXPOOLWITHINDEX(double)
-INSTANTIATEMAXPOOLWITHINDEX(half)
-INSTANTIATEMAXPOOLWITHINDEX(int8_t)
-INSTANTIATEMAXPOOLWITHINDEX(uint8_t)
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW)
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC)
+#endif
 
 }  // namespace cuda
 }  // namespace onnxruntime
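Because the template body lives in this .cu translation unit while callers only see the declaration in max_pool_with_index.h, every (T, Layout) pair used elsewhere must be explicitly instantiated here, which is what the reworked macro does; NHWC instantiations exist only under ENABLE_CUDA_NHWC_OPS. A minimal, self-contained analogue of the pattern (not ORT code):

```cpp
#include <cstdio>

template <typename T, bool Layout>
void MaxPoolLike(const T* in, T* out) {
  // kernel launch elided; Layout would pick the NCHW or NHWC indexing path
  out[0] = in[0];
}

// Explicit instantiation, as INSTANTIATEMAXPOOLWITHINDEX does per (T, Layout).
#define INSTANTIATE_MAXPOOL_LIKE(T, Layout) \
  template void MaxPoolLike<T, Layout>(const T* in, T* out);

INSTANTIATE_MAXPOOL_LIKE(float, false)  // LAYOUT_NCHW
INSTANTIATE_MAXPOOL_LIKE(float, true)   // LAYOUT_NHWC

int main() {
  float in = 1.0f, out = 0.0f;
  MaxPoolLike<float, true>(&in, &out);
  std::printf("%f\n", out);
  return 0;
}
```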
diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
index 27f5b241cc785..98f14c3f6a626 100644
--- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
+++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
@@ -7,7 +7,7 @@
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc
index 8bc96958693bc..4acdcfcf35491 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.cc
+++ b/onnxruntime/core/providers/cuda/nn/pool.cc
@@ -87,6 +87,8 @@ POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInt
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
 
 POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
 POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
@@ -145,8 +147,8 @@ class CudnnPoolingDescriptor final {
   cudnnPoolingDescriptor_t desc_;
 };
 
-template <typename T, typename PoolType, bool NHWC>
-Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, typename PoolType, bool Layout>
+Status Pool<T, PoolType, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -157,16 +159,21 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   }
 
   auto kernel_shape = pool_attrs_.kernel_shape;
-  auto pads = pool_attrs_.pads;
   auto strides = pool_attrs_.strides;
+  TensorShapeVector pads = pool_attrs_.pads;
 
   if (pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
-    pads.assign(kernel_shape.size(), 0);
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
+    }
+    pads.assign(2 * kernel_shape.size(), 0);
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[3] : x_shape[1];
-  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = (Layout == LAYOUT_NHWC) ? x_shape[x_dims.size() - 1] : x_shape[1];
+
+  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   TensorShape y_shape(y_dims);
   Tensor* Y = context->Output(0, y_shape);
   // special case when there is a dim value of 0 in the shape.
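For global pooling the window is the whole spatial extent, so kernel_shape is sliced out of the input dims around N and C (whose positions depend on layout); the diff also fixes pads to hold a begin and an end value per spatial axis (2 * kernel rank) instead of one entry per axis. A sketch of the rule under those assumptions (not ORT code):

```cpp
#include <cstdint>
#include <vector>

void SetGlobalPoolingAttrs(const std::vector<int64_t>& x_dims, bool is_nhwc,
                           std::vector<int64_t>& kernel_shape,
                           std::vector<int64_t>& pads,
                           std::vector<int64_t>& strides) {
  if (is_nhwc)  // NHWC: spatial dims sit between N (front) and C (back)
    kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
  else          // NCHW: spatial dims are everything after N and C
    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
  pads.assign(2 * kernel_shape.size(), 0);  // {x1_begin, x2_begin, ..., x1_end, x2_end, ...}
  strides.assign(kernel_shape.size(), 1);
}
// NHWC {1, 7, 7, 512} -> kernel_shape {7, 7}, pads {0, 0, 0, 0}, strides {1, 1}
```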
@@ -178,20 +185,22 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   TensorShapeVector x_dims_cudnn(x_dims.begin(), x_dims.end());
   TensorShapeVector y_dims_cudnn(y_dims);
   if (kernel_shape.size() < 2) {
-    // cudnn only takes 4D or 5D input, so pad dimensions if needed
-    if (NHWC) {
-      x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1);
-      y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1);
-      kernel_shape.insert(kernel_shape.begin() + 1, 1);
-      strides.insert(strides.begin() + 1, 1);
-    } else {
-      x_dims_cudnn.push_back(1);
-      y_dims_cudnn.push_back(1);
-      kernel_shape.push_back(1);
-      strides.push_back(1);
+    // cuDNN only takes 4D or 5D input, so pad dimensions if needed
+    if constexpr (Layout == LAYOUT_NHWC) {
+      x_dims_cudnn.insert(x_dims_cudnn.end() - 1, 1);
+      y_dims_cudnn.insert(y_dims_cudnn.end() - 1, 1);
+      pads.insert(pads.begin() + pads.size() / 2, 0);
+      pads.insert(pads.end(), 0);
+      kernel_shape.insert(kernel_shape.end(), 1);
+      strides.insert(strides.end(), 1);
+    } else {  // Layout == LAYOUT_NCHW
+      x_dims_cudnn.insert(x_dims_cudnn.end(), 1);
+      y_dims_cudnn.insert(y_dims_cudnn.end(), 1);
+      pads.insert(pads.begin() + pads.size() / 2, 0);
+      pads.insert(pads.end(), 0);
+      kernel_shape.insert(kernel_shape.end(), 1);
+      strides.insert(strides.end(), 1);
     }
-    pads.insert(pads.begin() + kernel_shape.size(), 0);
-    pads.insert(pads.end(), 0);
   }
 
   cudnnPoolingMode_t mode = CUDNN_POOLING_MAX;
@@ -208,8 +217,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
     const auto beta = Consts<CudaT>::Zero;
     CudnnTensor x_tensor;
     CudnnTensor y_tensor;
-    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
-    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
+    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
+    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
 
     const auto input_count = x_shape.Size();
     const auto output_count = y_shape.Size();
@@ -225,8 +234,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
     const auto beta = Consts<CudaT>::Zero;
     CudnnTensor x_tensor;
     CudnnTensor y_tensor;
-    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
-    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
+    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
+    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
 
     CUDNN_RETURN_IF_ERROR(
         PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
@@ -235,8 +244,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   return Status::OK();
 }
 
-template <typename T, bool NHWC>
-Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool Layout>
+Status Pool<T, MaxPool<8>, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -251,12 +260,16 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) cons
   auto strides = this->pool_attrs_.strides;
 
   if (this->pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
-    pads.assign(kernel_shape.size(), 0);
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
+    }
+    pads.assign(2 * kernel_shape.size(), 0);  // x{i}_begin + x{i}_end
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[3] : x_shape[1];
-  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = Layout == LAYOUT_NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1];
+  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   Tensor* Y = context->Output(0, TensorShape(y_dims));
 
   // special case when there is a dim value of 0 in the shape.
@@ -265,13 +278,22 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) cons
   auto x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
   auto y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
 
-  Tensor* I = context->Output(1, TensorShape(y_dims));
+  // I is in NCHW format and the contained indices use NCHW math to compute the index
+  auto i_dims = y_dims;
+  if constexpr (Layout == LAYOUT_NHWC) {
+    // y_dims is in NHWDC format; i_dims has to be in NCHWD format.
+    i_dims.insert(i_dims.begin() + 1, i_dims.back());  // NCHWDC
+    i_dims.pop_back();                                 // NCHWD
+  }
+
+  Tensor* I = context->Output(1, TensorShape(i_dims));
   if (nullptr != I || !this->pool_attrs_.default_dilations) {
     auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
-    MaxPoolWithIndex<CudaT>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads,
-                            this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data);
+    MaxPoolWithIndex<CudaT, Layout == LAYOUT_NHWC>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape,
+                                                   strides, pads, this->pool_attrs_.dilations,
+                                                   this->pool_attrs_.storage_order, x_data, y_data, i_data);
   } else {
-    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, NHWC>::ComputeInternal(context)));
+    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, Layout == LAYOUT_NHWC>::ComputeInternal(context)));
   }
   return Status::OK();
 }
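As the comment notes, the Indices output stays NCHW-shaped even on the NHWC path, so the output dims are permuted by moving the trailing C next to N before allocating the tensor. The permutation in isolation (a sketch, not ORT code):

```cpp
#include <cstdint>
#include <vector>

// Derive NCHW(D)-shaped Indices dims from NHW(D)C output dims.
std::vector<int64_t> NhwcToNchwDims(std::vector<int64_t> dims) {
  dims.insert(dims.begin() + 1, dims.back());  // copy C in after N
  dims.pop_back();                             // drop the trailing C
  return dims;
}
// {1, 4, 4, 8} (NHWC) -> {1, 8, 4, 4} (NCHW)
```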
diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h
index 8b5152a1565a9..97f7c8b8762d5 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.h
+++ b/onnxruntime/core/providers/cuda/nn/pool.h
@@ -19,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
-template <typename T, bool NHWC>
-class Pool<T, MaxPool<8>, NHWC> final : public Pool<T, MaxPool<1>, NHWC> {
+template <typename T, bool Layout>
+class Pool<T, MaxPool<8>, Layout> final : public Pool<T, MaxPool<1>, Layout> {
  public:
-  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, NHWC>(info) {}
+  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, Layout>(info) {}
 
   Status ComputeInternal(OpKernelContext* context) const override;
 };
diff --git a/onnxruntime/core/providers/rocm/nn/pool.cc b/onnxruntime/core/providers/rocm/nn/pool.cc
index 045c8b55c0b0d..3a82ab598004b 100644
--- a/onnxruntime/core/providers/rocm/nn/pool.cc
+++ b/onnxruntime/core/providers/rocm/nn/pool.cc
@@ -257,7 +257,7 @@ Status Pool<T, MaxPool<8>>::ComputeInternal(OpKernelContext* context) const {
   Tensor* I = context->Output(1, TensorShape(y_dims));
   if (nullptr != I || !this->pool_attrs_.default_dilations) {
     auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
-    MaxPoolWithIndex<HipT>(
+    MaxPoolWithIndex<HipT, false>(
         this->Stream(context),
         x_shape,
         TensorShape(y_dims),
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 4b194ec18b31b..11c0ce0efa508 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -57,7 +57,8 @@ TEST(PoolTest, MaxPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: result differs
+  // TensorRT: result differs
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 // Only CUDA kernel has float 16 support
@@ -115,7 +116,8 @@ TEST(PoolTest, MaxPool_F16) {
   test.AddInput<MLFloat16>("X", x_dims, f_X);
   test.AddOutput<MLFloat16>("Y", expected_dims, f_Y);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Assertion `!attrs.count("pads")' failed
+  // TensorRT: Assertion `!attrs.count("pads")' failed
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 #endif
 
@@ -167,7 +169,9 @@ static void MaxPool_8_WithIndexTest(bool has_index, int64_t storage_order = 0) {
test.AddOutput("Indices", expected_dims, expected_indices_row) : test.AddOutput("Indices", expected_dims, expected_indices_col); } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kDnnlExecutionProvider, kTensorrtExecutionProvider, + kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, MaxPool_8_With_Index) { @@ -181,7 +185,7 @@ TEST(PoolTest, MaxPool_8_With_Index) { MaxPool_8_WithIndexTest(true, 1 /*storage_order*/); // col major } -TEST(PoolTest, MaxPool1D) { +TEST(PoolTest, MaxPool1D_case1) { OpTester test("MaxPool"); test.AddAttribute("auto_pad", ""); @@ -199,6 +203,44 @@ TEST(PoolTest, MaxPool1D) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(PoolTest, MaxPool1D_case2) { + OpTester test("MaxPool"); + // no padding + test.AddAttribute("auto_pad", "VALID"); + test.AddAttribute("strides", std::vector{1}); + test.AddAttribute("pads", vector{0, 0}); + test.AddAttribute("kernel_shape", vector{2}); + + std::vector x_vals = {1, 2, 3, 4, 5}; + std::vector x_dims = {1, 1, 5}; + // The last dim is (5-2+1)/1 = 4 + std::vector expected_dims = {1, 1, 4}; + std::vector expected_vals = {2, 3, 4, 5}; + + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(PoolTest, MaxPool1D_case3) { + OpTester test("MaxPool"); + test.AddAttribute("auto_pad", ""); + test.AddAttribute("strides", std::vector{1}); + // Pad one element + test.AddAttribute("pads", vector{0, 1}); + test.AddAttribute("kernel_shape", vector{2}); + + std::vector x_vals = {1, 2, 3, 4, 5}; + std::vector x_dims = {1, 1, 5}; + // Since we padded it, the last dim is larger compared to the case above + std::vector expected_dims = {1, 1, 5}; + std::vector expected_vals = {2, 3, 4, 5, 5}; + + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { OpTester test("MaxPool", 8); @@ -217,7 +259,8 @@ static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_8_With_Index) { @@ -243,7 +286,8 @@ static void MaxPool1D_12_WithIndexTest_int8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { @@ -264,7 +308,8 @@ static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); 
test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_12_With_Index_8bits) { @@ -304,7 +349,7 @@ TEST(PoolTest, MaxPool2D_uint8) { #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); #else - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); #endif } @@ -528,7 +573,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { @@ -556,7 +602,8 @@ TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { @@ -585,7 +632,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_3d) { @@ -697,7 +745,7 @@ TEST(PoolTest, GlobalMaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } TEST(PoolTest, GlobalMaxPool3D) { @@ -878,6 +926,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); + test.SetOutputTolerance(0.0001f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -920,7 +969,8 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, AveragePool_19_dilation_2d) { @@ -944,7 +994,9 @@ TEST(PoolTest, AveragePool_19_dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, + kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, GlobalAveragePool) { @@ -1020,7 +1072,7 @@ TEST(PoolTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + 
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_128) {
@@ -1033,7 +1085,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_256) {
@@ -1046,7 +1098,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, LpPool) {
@@ -1353,7 +1405,7 @@ TEST(PoolTest, LpPool) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 // test data generated with lp_pool_test_generator.py
@@ -1385,7 +1437,8 @@ TEST(PoolTest, LpPool1d) {
 
     // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
     // TensorRT does not support 1d pooling
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+             {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
     y_count++;
   }
 }
@@ -1417,7 +1470,7 @@ TEST(PoolTest, LpPool2d) {
     test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]);
     test.AddOutput<float>("Y", y_sizes[y_count], ys[y_count]);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
     y_count++;
   }
 }
@@ -1435,7 +1488,8 @@ TEST(PoolTest, LpPoolCeilMode) {
 
   // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
   // TensorRT does not support 1d pooling
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, GlobalLpPool) {
@@ -1690,7 +1744,7 @@ TEST(PoolTest, GlobalLpPool) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 TEST(PoolTest, MaxPoolDimWithZeroForN) {
@@ -1707,7 +1761,8 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kQnnExecutionProvider});
 }
 
 }  // namespace test