Skip to content

Commit

Permalink
formatting and adding more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gedoensmax committed Oct 11, 2023
1 parent 8b15fb8 commit b2402a9
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 197 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
}

cudaStream_t ComputeStream() {
// Returns the CUDA EP-level stream held in stream_. This can differ from the stream
// that the actual compute tasks run on; that compute task stream is supplied
// through OpKernelContext during inference.
return stream_;
}

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class CudaKernel : public OpKernel {
}

inline cudaStream_t DefaultCudaStream() const {
// Forwards to provider_->ComputeStream(), i.e. the CUDA EP-level stream. This can
// differ from the stream that the actual compute tasks run on; that compute task
// stream is supplied through OpKernelContext during inference.
return provider_->ComputeStream();
}

Expand Down
186 changes: 124 additions & 62 deletions onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/nn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class Conv : public CudaKernel {
constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
static const cudnnConvolutionFwdAlgo_t kAllAlgos[];
std::unique_ptr<Tensor> W_;

Check warning on line 209 in onnxruntime/core/providers/cuda/nn/conv.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/nn/conv.h:209: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain
bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain
};

Status SliceOutUnwantedOutputSection(cudaStream_t stream,
Expand Down
103 changes: 32 additions & 71 deletions onnxruntime/core/providers/cuda/nn/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,31 @@ namespace cuda {

#define POOLING_KERNEL(op_name, data_type, pool_type, since_version, op_domain, nhwc) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
op_name, \
op_domain, \
since_version, \
data_type, \
kCudaExecutionProvider, \
op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Pool<data_type, pool_type, nhwc>);

#define POOLING_KERNEL_VERSIONED(op_name, data_type, pool_type, since_version, end_version, op_domain, nhwc) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
op_name, \
op_domain, \
since_version, \
end_version, \
data_type, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
op_name, op_domain, since_version, end_version, data_type, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Pool<data_type, pool_type, nhwc>);

#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version, op_domain, nhwc) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
op_name, \
op_domain, \
since_version, \
data_type, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, nhwc>);

#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version, op_domain, nhwc) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
op_name, \
op_domain, \
since_version, \
end_version, \
data_type, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, nhwc>);
#define POOLING_KERNEL_WITH_INDICES(op_name, data_type, pool_type, since_version, op_domain, nhwc) \
ONNX_OPERATOR_TYPED_KERNEL_EX(op_name, op_domain, since_version, data_type, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, nhwc>);

#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, data_type, pool_type, since_version, end_version, op_domain, \
nhwc) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(op_name, op_domain, since_version, end_version, data_type, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, nhwc>);

POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9, kOnnxDomain, false)
POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9, kOnnxDomain, false)
Expand Down Expand Up @@ -123,8 +104,7 @@ POOLING_KERNEL(GlobalAveragePool, MLFloat16, AveragePool, 1, kMSInternalNHWCDoma

class CudnnPoolingDescriptor final {
public:
CudnnPoolingDescriptor() : desc_(nullptr) {
}
CudnnPoolingDescriptor() : desc_(nullptr) {}

~CudnnPoolingDescriptor() {
if (desc_ != nullptr) {
Expand All @@ -136,12 +116,9 @@ class CudnnPoolingDescriptor final {
CudnnPoolingDescriptor(const CudnnPoolingDescriptor&) = delete;
CudnnPoolingDescriptor& operator=(const CudnnPoolingDescriptor&) = delete;

Status Set(cudnnPoolingMode_t mode,
const gsl::span<const int64_t>& kernel_shape,
const gsl::span<const int64_t>& pads,
const gsl::span<const int64_t>& strides) {
if (!desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_));
Status Set(cudnnPoolingMode_t mode, const gsl::span<const int64_t>& kernel_shape,
const gsl::span<const int64_t>& pads, const gsl::span<const int64_t>& strides) {
if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreatePoolingDescriptor(&desc_));

int rank = gsl::narrow_cast<int>(kernel_shape.size());
InlinedVector<int> window(rank);
Expand All @@ -156,14 +133,8 @@ class CudnnPoolingDescriptor final {
for (int i = 0; i < rank; i++) {
stride[i] = gsl::narrow_cast<int>(strides[i]);
}
CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(
desc_,
mode,
CUDNN_PROPAGATE_NAN,
rank,
window.data(),
padding.data(),
stride.data()));
CUDNN_RETURN_IF_ERROR(SetPoolingNdDescriptorHelper(desc_, mode, CUDNN_PROPAGATE_NAN, rank, window.data(),
padding.data(), stride.data()));

return Status::OK();
}
Expand Down Expand Up @@ -199,8 +170,7 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
TensorShape y_shape(y_dims);
Tensor* Y = context->Output(0, y_shape);
// special case when there is a dim value of 0 in the shape.
if (y_shape.Size() == 0)
return Status::OK();
if (y_shape.Size() == 0) return Status::OK();

auto x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
auto y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
Expand Down Expand Up @@ -247,7 +217,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count, context->GetComputeStream());
auto temp_Y = GetScratchBuffer<float>(output_count, context->GetComputeStream());
Impl_Cast<CudaT, float>(Stream(context), reinterpret_cast<const CudaT*>(x_data), temp_X.get(), input_count);
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get()));
CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(),
&beta, y_tensor, temp_Y.get()));
Impl_Cast<float, CudaT>(Stream(context), temp_Y.get(), y_data, output_count);
} else {
const auto alpha = Consts<CudaT>::One;
Expand All @@ -257,7 +228,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));

CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
CUDNN_RETURN_IF_ERROR(
PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
}

return Status::OK();
Expand Down Expand Up @@ -288,27 +260,16 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) cons
Tensor* Y = context->Output(0, TensorShape(y_dims));

// special case when there is a dim value of 0 in the shape.
if (Y->Shape().Size() == 0)
return Status::OK();
if (Y->Shape().Size() == 0) return Status::OK();

auto x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
auto y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());

Tensor* I = context->Output(1, TensorShape(y_dims));
if (nullptr != I || !this->pool_attrs_.default_dilations) {
auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
MaxPoolWithIndex<CudaT>(
this->Stream(context),
x_shape,
TensorShape(y_dims),
kernel_shape,
strides,
pads,
this->pool_attrs_.dilations,
this->pool_attrs_.storage_order,
x_data,
y_data,
i_data);
MaxPoolWithIndex<CudaT>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads,
this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data);
} else {
ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, NHWC>::ComputeInternal(context)));
}
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/test/providers/compare_provider_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

#pragma once

#include <string>
#include <memory>
#include <vector>
#include <unordered_map>

#include "core/graph/constants.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
Expand Down
24 changes: 6 additions & 18 deletions onnxruntime/test/providers/cuda/nhwc/conv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) 2023 NVIDIA Corporation.
// Licensed under the MIT License.

#include "nhwc_cuda_helper.h"
#include "test/providers/cuda/nhwc/nhwc_cuda_helper.h"

namespace onnxruntime {
namespace test {
Expand Down Expand Up @@ -41,8 +41,7 @@ struct ConvOp {
test->AddAttribute("pads", padding);

std::vector<int64_t> output_dims = {
input_dims[0],
channels,
input_dims[0], channels,
ComputeOutputShape(input_dims[2], strides[0], kernel_shape[0], dilations[0], padding[0], padding[1]),
ComputeOutputShape(input_dims[3], strides[1], kernel_shape[1], dilations[1], padding[2], padding[3])};
std::vector<T> output_data = FillZeros<T>(output_dims);
Expand All @@ -53,31 +52,20 @@ struct ConvOp {
};

TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) {
auto op = ConvOp<TypeParam>{
.input_dims = {1, 16, 64, 64},
.kernel_shape = {3, 3},
.channels = 16,
.bias = true};
auto op = ConvOp<TypeParam>{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}

TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) {
auto op = ConvOp<TypeParam>{
.input_dims = {1, 16, 64, 64},
.kernel_shape = {3, 3},
.channels = 16,
.group = 4};
auto op = ConvOp<TypeParam>{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}

TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) {
auto op = ConvOp<TypeParam>{
.input_dims = {2, 4, 64, 64},
.kernel_shape = {3, 3},
.channels = 4,
.padding = {4, 4, 4, 4}};
auto op =
ConvOp<TypeParam>{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}
Expand Down
41 changes: 16 additions & 25 deletions onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) 2023 NVIDIA Corporation.
// Licensed under the MIT License.

#include "nhwc_cuda_helper.h"
#include "test/providers/cuda/nhwc/nhwc_cuda_helper.h"

namespace onnxruntime {
namespace test {
Expand Down Expand Up @@ -45,8 +45,7 @@ struct ConvTransposeOp {
}

std::vector<int64_t> output_dims = {
input_dims[0],
channels,
input_dims[0], channels,
(kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1,
(kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1};
std::vector<T> output_data = FillZeros<T>(output_dims);
Expand All @@ -57,43 +56,35 @@ struct ConvTransposeOp {
};

TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) {
auto op = ConvTransposeOp<TypeParam>{
.input_dims = {8, 8, 32, 32},
.kernel_shape = {3, 3},
.channels = 16,
.group = 4};
auto op =
ConvTransposeOp<TypeParam>{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}

TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) {
auto op = ConvTransposeOp<TypeParam>{
.input_dims = {1, 8, 80, 80},
.kernel_shape = {5, 5},
.channels = 16,
.bias = true};
auto op =
ConvTransposeOp<TypeParam>{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}

TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) {
auto op = ConvTransposeOp<TypeParam>{
.input_dims = {1, 16, 8, 8},
.kernel_shape = {3, 3},
.channels = 32,
.padding = {2, 2, 2, 2},
.output_padding = {}};
auto op = ConvTransposeOp<TypeParam>{.input_dims = {1, 16, 8, 8},
.kernel_shape = {3, 3},
.channels = 32,
.padding = {2, 2, 2, 2},
.output_padding = {}};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}

TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) {
auto op = ConvTransposeOp<TypeParam>{
.input_dims = {1, 32, 8, 8},
.kernel_shape = {3, 3},
.channels = 32,
.strides = {2, 2},
.output_padding = {1, 1, 1, 1}};
auto op = ConvTransposeOp<TypeParam>{.input_dims = {1, 32, 8, 8},
.kernel_shape = {3, 3},
.channels = 32,
.strides = {2, 2},
.output_padding = {1, 1, 1, 1}};

MAKE_PROVIDERS_EPS_TYPE(TypeParam)
}
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Copyright (c) 2023 NVIDIA Corporation.
// Licensed under the MIT License.

#include <vector>
#include "core/providers/cuda/cuda_provider_options.h"
#include "core/providers/common.h"

Expand All @@ -12,13 +13,11 @@

#define MAKE_PROVIDERS_EPS(eps) \
std::vector<std::shared_ptr<IExecutionProvider>> execution_providers; \
OrtCUDAProviderOptionsV2 nhwc = { \
.prefer_nhwc = true}; \
OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; \
execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \
\
double error_tolerance = eps; \
OrtCUDAProviderOptionsV2 nchw = { \
.prefer_nhwc = false}; \
OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; \
auto source_ep = CudaExecutionProviderWithOptions(&nchw); \
auto test = op.get_test(); \
test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance);
Expand All @@ -37,8 +36,7 @@ namespace onnxruntime {
namespace test {

template <typename T>
class CudaNhwcTypedTest : public ::testing::Test {
};
class CudaNhwcTypedTest : public ::testing::Test {};

using CudaNhwcTestTypes = ::testing::Types<float, MLFloat16>; // double,
TYPED_TEST_SUITE(CudaNhwcTypedTest, CudaNhwcTestTypes);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/cuda/nhwc/norm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) 2023 NVIDIA Corporation.
// Licensed under the MIT License.

#include "nhwc_cuda_helper.h"
#include "test/providers/cuda/nhwc/nhwc_cuda_helper.h"

namespace onnxruntime {
namespace test {
Expand Down
Loading

0 comments on commit b2402a9

Please sign in to comment.