Fix #if __CUDACC__ ranges
JTischbein committed Mar 11, 2024
1 parent 590497b commit 7a209cb
Showing 3 changed files with 7 additions and 6 deletions.
onnxruntime/core/providers/cuda/nn/conv.cc (2 changes: 1 addition & 1 deletion)
@@ -378,6 +378,7 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
   auto handle = GetCudnnHandle(context);
 
   int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
+#if !defined(__CUDACC__)
   cudnn_frontend::HeurMode_t heur_mode;
   switch (cudnn_conv_algo) {
     case 0:
@@ -396,7 +397,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
   size_t kernel_shape_size = kernel_shape.size();
 
-#if !defined(__CUDACC__)
   ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(X, W, B, y_dims_cudnn, handle, heur_mode,
                                                  std::vector<int64_t>(pads.begin(),
                                                                       pads.begin() + kernel_shape_size),
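Note: the two hunks above move the opening #if !defined(__CUDACC__) so that it precedes the first use of a cudnn_frontend type instead of the later CreateCudnnFeExecutionPlan call. A minimal sketch of the pattern, assuming the usual behavior that NVCC-compiled translation units define __CUDACC__ and therefore never see <cudnn_frontend.h>; UpdateStateSketch and its body are invented stand-ins, not code from this commit:

// guard_sketch.cc -- illustrative stand-in only; everything except the guard pattern is invented.
#if !defined(__CUDACC__)
#include <cudnn_frontend.h>  // never parsed when NVCC compiles this translation unit
#endif

void UpdateStateSketch(int cudnn_conv_algo) {
#if !defined(__CUDACC__)
  // The guard must open before the first cudnn_frontend name appears;
  // otherwise NVCC builds, which skipped the include above, fail on HeurMode_t.
  cudnn_frontend::HeurMode_t heur_mode{};
  switch (cudnn_conv_algo) {
    case 0:
      // ... pick one heuristic mode ...
      break;
    default:
      // ... pick another ...
      break;
  }
  (void)heur_mode;  // in the real code this is forwarded to CreateCudnnFeExecutionPlan
#endif
}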
onnxruntime/core/providers/cuda/nn/conv.h (9 changes: 6 additions & 3 deletions)
@@ -7,11 +7,14 @@
 #include <list>
 #include <memory>
 
+#if !defined(__CUDACC__)
+#include <cudnn_frontend.h>
+#endif
+
 #include "core/platform/ort_mutex.h"
 #include "core/providers/cuda/cuda_kernel.h"
 #include "core/providers/cuda/cudnn_common.h"
 #include "core/providers/cpu/nn/conv_attributes.h"
-#include <cudnn_frontend.h>
 
 namespace onnxruntime {
 
@@ -223,8 +226,8 @@ class Conv : public CudaKernel {
   Status UpdateState(OpKernelContext* context, bool bias_expected = false) const;
 
 #if !defined(__CUDACC__)
-  Status CreateCudnnFeExecutionPlan(const Tensor* X, const Tensor* W, const Tensor* B, cudnnContext* handle, const cudnn_frontend::HeurMode_t heur_mode,
-                                    const std::vector<int64_t>& pads, const std::vector<int64_t>& strides, const std::vector<int64_t>& dilations, const bool bias_expected, const bool fuse_bias) const;
+  Status CreateCudnnFeExecutionPlan(const Tensor* X, const Tensor* W, const Tensor* B, const TensorShapeVector& y_dims, cudnnContext* handle, const cudnn_frontend::HeurMode_t heur_mode,
+                                    const std::vector<int64_t>& pads, const std::vector<int64_t>& strides, const std::vector<int64_t>& dilations, const bool bias_expected, const bool fuse_bias) const;
 #endif
 
   ConvAttributes conv_attrs_;
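Note: the header change follows the same rule: the cudnn_frontend include and every declaration whose signature names a cudnn_frontend type are fenced off, while the rest of the class stays visible to both the host compiler and NVCC. A rough sketch of that layout, with invented class and member names (only the guard placement mirrors the diff):

// conv_like.h -- hypothetical layout, not the real Conv header.
#pragma once

#include <cstdint>
#include <vector>

#if !defined(__CUDACC__)
#include <cudnn_frontend.h>  // host compiler only
#endif

class ConvLike {
 public:
  // Visible to every translation unit, including .cu files built by NVCC.
  void UpdateStateSketch();

#if !defined(__CUDACC__)
  // Declared only for host-compiled units: the parameter list names a
  // cudnn_frontend type, so NVCC must never see this declaration.
  void CreatePlanSketch(cudnn_frontend::HeurMode_t heur_mode,
                        const std::vector<int64_t>& pads) const;
#endif
};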
onnxruntime/test/providers/cpu/nn/conv_op_test.cc (2 changes: 0 additions & 2 deletions)
@@ -59,8 +59,6 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
   std::unordered_set<std::string> excluded_providers(attributes.excluded_providers);
   // Disable TensorRT because weight as input is not supported
   excluded_providers.insert(kTensorrtExecutionProvider);
-  // Disable CUDA NHWC execution provider as it is currently flaky
-  excluded_providers.insert(kCudaNHWCExecutionProvider);
 
   // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs.
   excluded_providers.insert(kQnnExecutionProvider);
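Note: removing the two lines above re-enables the CUDA NHWC execution provider for these conv tests; TensorRT and QNN remain excluded. A self-contained sketch of the exclusion pattern, using a stand-in constant and helper name rather than the test's own OpTester plumbing:

// exclusion_sketch.cc -- stand-in names; only the insert pattern mirrors the test.
#include <string>
#include <unordered_set>

constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";  // stand-in value

std::unordered_set<std::string> BuildExcludedProviders(
    const std::unordered_set<std::string>& from_attributes) {
  std::unordered_set<std::string> excluded_providers(from_attributes);
  // TensorRT stays excluded because weight-as-input is not supported there.
  excluded_providers.insert(kTensorrtExecutionProvider);
  // The CUDA NHWC provider is no longer added here, so those tests run against it again.
  return excluded_providers;
}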
