From 48b2750eea1220af8bc5bd6b2cddd16d93b7200d Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Thu, 15 Aug 2024 12:28:43 +0200 Subject: [PATCH 1/5] ConvTranpose using CUDNN Frontend with NHWC support and small fix in conv.cc --- .../providers/cuda/cuda_execution_provider.cc | 2 + onnxruntime/core/providers/cuda/nn/conv.cc | 2 +- .../core/providers/cuda/nn/conv_transpose.cc | 625 +++++++++++------- .../core/providers/cuda/nn/conv_transpose.h | 31 +- .../core/providers/cuda/nn/conv_transpose_8.h | 266 ++++++++ 5 files changed, 702 insertions(+), 224 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/nn/conv_transpose_8.h diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index f74754c3cd064..16d18ee4d86fb 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2411,6 +2411,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { +#if CUDNN_MAJOR < 9 const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -2457,6 +2458,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const } } } +#endif #ifdef ENABLE_CUDA_NHWC_OPS if (prefer_nhwc) { diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 95ba698b707ac..cc76198dc3ae9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -385,7 +385,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected if (cuda_ep->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - w_dims_cudnn.insert(w_dims.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index bac99d6a81ed2..b3fee863b5f19 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -7,6 +7,11 @@ #include "conv_transpose.h" #include "core/providers/cuda/tensor/transpose.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_transpose_8.h" +#endif + // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 @@ -38,48 +43,42 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { - return DoConvTranspose(context, false); -} - +// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, - PrePackedWeights* prepacked_weights) { + [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted 
via PrePack - if constexpr (NHWC) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - InlinedVector perm; - TensorShapeVector new_dims; - - // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M. - // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is. - if (rank == 3) { - // Transpose from {C, M/group, k1} to {C, k1, M/group} - perm = {0, 2, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]}; - } else if (rank == 4) { - // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group} - perm = {0, 2, 3, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; - } else if (rank == 5) { - // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group} - perm = {0, 2, 3, 4, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]}; - } + auto shape_size = orig_shape.GetDims().size(); + + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); - gsl::span permutation(perm.data(), rank); - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } - ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), - permutation, tensor, *W_)); + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -91,236 +90,418 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Allo return Status::OK(); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - typedef typename ToCudaType::MappedType CudaT; +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + 
s_.cudnn_fe_graph->set_compute_data_type(data_type); + } - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - auto x_dims = x_shape.AsShapeVector(); - auto x_data = reinterpret_cast(X->Data()); - - auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions < 3 || x_dimensions > 5) { - // TODO: the error message should tell which operator raises it. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", - " X: ", X->Shape().ToString().c_str()); + s_.cudnn_fe_X = s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_dgrad_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_dgrad(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + if (fuse_bias) { + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? 
+ cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); } - // use pre-packed W if available - const Tensor* W = W_ ? W_.get() : context->Input(1); + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } - const TensorShape& w_shape = W->Shape(); - TensorShapeVector w_dims = w_shape.AsShapeVector(); - auto w_data = reinterpret_cast(W->Data()); + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, y_dims, handle, heur_mode, + pads, strides, dilations, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { + constexpr bool channels_last = Layout == LAYOUT_NHWC; size_t num_inputs = OpKernel::Node().InputDefs().size(); bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - CudaT* y_data = nullptr; + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + // X incl. x_dims is in NHWC Format iff. 
NHWC == true + const auto x_dims = x_shape.AsShapeVector(); + + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + + // set W + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = false; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = channels_last; + // W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + // Always in NCHW format + const Tensor* B = nullptr; + if (has_bias) { + B = context->Input(dynamic_padding ? 3 : 2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } + + const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); - // convert 1D to 2D - if (x_dimensions == 3) { - // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use - // GetCudnnConv1dPadToNc1d to determine which is added. - // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension - const auto insert_at = NHWC ? 1 : 2; + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + } - // NCHW: N, C, d1 -> N, C, 1, d1 - // NHWC: N, d1, C -> N, 1, d1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + // The following code is from ConvTransposeAttributes::PrepareForCompute - // 'M' is channels dim in CUDA implementation - // NCHW: C, M/g, k1 -> C, M/g, 1, k1 - // NHWC: C, k1, M/g -> C, 1, k1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); - } else { - // add fake W dimension - const auto insert_at = NHWC ? 2 : 3; + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(channels_last ? 1 : 2, channels_last ? rank - 1 : rank); + const int64_t num_input_channels = channels_last ? X->Shape()[rank - 1] : X->Shape()[1]; + const int64_t N = X->Shape()[0]; + const int64_t num_output_channels_multiplier = w_in_nhwc ? 
w_shape[rank - 1] : w_shape[1]; + const int64_t num_output_channels = num_output_channels_multiplier * conv_transpose_attrs_.group; - // NCHW: N, C, d1 -> N, C, d1, 1 - // NHWC: N, d1, C -> N, d1, 1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + if (conv_transpose_attrs_.group <= 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", + " group: ", conv_transpose_attrs_.group); + } - // NCHW: C, M/g, k1 -> C, M/g, k1, 1 - // NHWC: C, k1, M/g -> C, k1, 1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); + if (X->Shape().NumDimensions() != w_shape.NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", + " X: ", X->Shape().ToString().c_str(), + " W: ", w_shape.ToString().c_str()); } - } - { - std::lock_guard lock(s_.mutex); - // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size - bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); - bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) { - s_.last_x_dims = gsl::make_span(x_dims); - } + if (w_shape[0] != num_input_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", + " filter_number: ", w_shape[0], + " num_input_channels: ", num_input_channels); + } - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); - } + // it looks like num_output_channels is really k*group similar to how in the conv case + // num_input_channels is k*group. hence removing the check for num_output_channels here. - ConvTransposeAttributes::Prepare p; - // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' - const bool transposed_input_channels = false; - ORT_RETURN_IF_ERROR( - conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels)); - - auto y_dims = p.Y->Shape().AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension of 1 - // NCHW: N, M, d1 -> N, M, 1, d1 or - // NHWC: N, d1, M -> N, 1, d1, M - y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); - } else { - // add fake W dimension of 1 - // NCHW: N, M, d1 -> N, M, d1, 1 or - // NHWC: N, d1, M -> N, d1, 1, M - y_dims.insert(y_dims.begin() + (NHWC ? 
2 : 3), 1); - p.kernel_shape.push_back(1); - p.pads.insert(p.pads.begin() + 1, 0); - p.pads.push_back(0); - p.strides.push_back(1); - p.dilations.push_back(1); - } - } + if (num_input_channels % conv_transpose_attrs_.group != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", + " num_input_channels: ", num_input_channels, + " group: ", conv_transpose_attrs_.group); + } - s_.y_dims = gsl::make_span(y_dims); + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc)); - if (w_dims_changed) { - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(w_dims[0]), static_cast(w_dims[3]), - static_cast(w_dims[1]), static_cast(w_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } - } + const size_t kernel_rank = kernel_shape.size(); - // Special case when there is a dim value of 0 in the shape. - // Return only after we have cached the following for subsequent runs : - // 1) `w_dims` in the `w_desc` - // 2) `y_dims` in s_.y_dims - if (p.Y->Shape().Size() == 0) { - return Status::OK(); + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape.size(), 0); + } + ConvPadVector pads; + pads.reserve(2 * (input_shape.NumDimensions())); + if (dynamic_padding) { + for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { + pads.push_back(Pads->Data()[i]); } + } else { + pads.assign(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + } + if (pads.empty()) { + pads.resize(kernel_shape.size() * 2, 0); + } + TensorShapeVector dilations(conv_transpose_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_shape.size(), 1); + } + TensorShapeVector strides(conv_transpose_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_shape.size(), 1); + } - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(x_dims[0]), static_cast(x_dims[3]), - static_cast(x_dims[1]), static_cast(x_dims[2]))); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(y_dims[0]), static_cast(y_dims[3]), - static_cast(y_dims[1]), static_cast(y_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); - } + TensorShapeVector y_dims; - cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType(), - UseTF32())); - - if (has_bias) { - const auto& b_shape = p.B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[NHWC ? 3 : 1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) { - b_dims[(NHWC ? 
1 : 2) + i] = 1; - } - - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); - } + conv_transpose_attrs_.ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, + strides, dilations, local_output_padding, N, &pads, &y_dims, channels_last); + TensorShape Yshape(y_dims); + s_.Y = context->Output(0, Yshape); - y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionBwdDataAlgoPerf_t perf; - int algo_count = 1; - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, - &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); - s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); - const auto& perf = s_.cached_benchmark_results.at(x_dims); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; - } + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; - // The following block will be executed in case there has been no change in the shapes of the - // input and the filter compared to the previous run - if (!y_data) { - auto y_dims = s_.y_dims.AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // erase the fake H dimension - y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); - } else { - // erase the fake W dimension - y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); - } - } + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->MutableData()); + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); + } + } - // Bail out early if one of the output dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. 
For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims_cudnn.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); } } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + } - CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, - x_data, s_.conv_desc, s_.algo, workspace.get(), - s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + auto use_tf32 = cuda_ep->UseTF32(); + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + return Status::OK(); +} - if (has_bias) { - const Tensor* B = dynamic_padding ? 
context->Input(3) : context->Input(2); - auto b_data = reinterpret_cast(B->Data()); - CUDNN_RETURN_IF_ERROR( - cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + std::lock_guard lock(s_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context, dynamic_padding)); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + const auto alpha = onnxruntime::cuda::Consts::One; + auto cudnn_handle = GetCudnnHandle(context); +#if !defined(__CUDACC__) + s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data); + if (s_.bias_fused && s_.b_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + } + if (s_.bias_fused && s_.z_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data)); + if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) { + // memset Z if it's required for a succesful fusion + CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } + auto ws = GetWorkSpace(context->GetComputeStream()); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, + s_.variant_pack, + ws.get())); + + if (!s_.bias_fused && s_.z_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data, + &alpha, s_.y_tensor, s_.y_data)); + } + if (!s_.bias_fused && s_.b_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + + /* For the standalone bias addition graph. 
+ s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ + } +#endif return Status::OK(); } +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { + return DoConvTranspose(context, false); +} } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 71ad3ee6e2147..3b8f117522210 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -18,7 +18,9 @@ namespace cuda { template class ConvTranspose : public CudaKernel { public: - ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; + using CudaT = typename ToCudaType::MappedType; + + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; @@ -29,6 +31,33 @@ class ConvTranspose : public CudaKernel { mutable CudnnConvState s_; std::unique_ptr W_; + + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain + + protected: + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + Status UpdateState(OpKernelContext* context, bool bias_expected) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h new file mode 100644 index 0000000000000..b46d41b887e41 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
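+//
+// Legacy cuDNN 8 path: this header holds the previous ConvTranspose implementation based on the
+// cuDNN v8 API (cudnnFindConvolutionBackwardDataAlgorithmEx / cudnnConvolutionBackwardData). It is
+// included from conv_transpose.cc only when building against cuDNN 8 (CUDNN_MAJOR < 9); with cuDNN 9
+// the cuDNN Frontend graph path defined in conv_transpose.cc is used instead.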
+ +#include + +#include "conv_transpose.h" +#include + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +#include "core/providers/cuda/tensor/transpose.h" + +// To suppress FP static analyzer warnings: +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26110) +#pragma warning(disable : 26117) +#endif + +namespace onnxruntime { +namespace cuda { + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + auto x_dims = x_shape.AsShapeVector(); + auto x_data = reinterpret_cast(X->Data()); + + auto x_dimensions = X->Shape().NumDimensions(); + if (x_dimensions < 3 || x_dimensions > 5) { + // TODO: the error message should tell which operator raises it. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", + " X: ", X->Shape().ToString().c_str()); + } + + // use pre-packed W if available + const Tensor* W = W_ ? W_.get() : context->Input(1); + + const TensorShape& w_shape = W->Shape(); + TensorShapeVector w_dims = w_shape.AsShapeVector(); + auto w_data = reinterpret_cast(W->Data()); + + size_t num_inputs = OpKernel::Node().InputDefs().size(); + bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; + + CudaT* y_data = nullptr; + + const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + + // convert 1D to 2D + if (x_dimensions == 3) { + // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use + // GetCudnnConv1dPadToNc1d to determine which is added. + // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension + const auto insert_at = NHWC ? 1 : 2; + + // NCHW: N, C, d1 -> N, C, 1, d1 + // NHWC: N, d1, C -> N, 1, d1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // 'M' is channels dim in CUDA implementation + // NCHW: C, M/g, k1 -> C, M/g, 1, k1 + // NHWC: C, k1, M/g -> C, 1, k1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } else { + // add fake W dimension + const auto insert_at = NHWC ? 
2 : 3; + + // NCHW: N, C, d1 -> N, C, d1, 1 + // NHWC: N, d1, C -> N, d1, 1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // NCHW: C, M/g, k1 -> C, M/g, k1, 1 + // NHWC: C, k1, M/g -> C, k1, 1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } + } + + { + std::lock_guard lock(s_.mutex); + // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); + // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with + // different batch_size + bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); + bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) { + s_.last_x_dims = gsl::make_span(x_dims); + } + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ConvTransposeAttributes::Prepare p; + // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' + const bool transposed_input_channels = false; + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, + &w_shape, NHWC, transposed_input_channels)); + + auto y_dims = p.Y->Shape().AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension of 1 + // NCHW: N, M, d1 -> N, M, 1, d1 or + // NHWC: N, d1, M -> N, 1, d1, M + y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); + p.kernel_shape.insert(p.kernel_shape.begin(), 1); + p.pads.insert(p.pads.begin(), 0); + p.pads.insert(p.pads.begin() + 2, 0); + p.strides.insert(p.strides.begin(), 1); + p.dilations.insert(p.dilations.begin(), 1); + } else { + // add fake W dimension of 1 + // NCHW: N, M, d1 -> N, M, d1, 1 or + // NHWC: N, d1, M -> N, d1, 1, M + y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); + p.kernel_shape.push_back(1); + p.pads.insert(p.pads.begin() + 1, 0); + p.pads.push_back(0); + p.strides.push_back(1); + p.dilations.push_back(1); + } + } + + s_.y_dims = gsl::make_span(y_dims); + + if (w_dims_changed) { + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } + + // Special case when there is a dim value of 0 in the shape. 
+ // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `w_desc` + // 2) `y_dims` in s_.y_dims + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } + + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType(), + UseTF32())); + + if (has_bias) { + const auto& b_shape = p.B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + p.kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) { + b_dims[(NHWC ? 1 : 2) + i] = 1; + } + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); + } + + y_data = reinterpret_cast(p.Y->MutableData()); + + if (!s_.cached_benchmark_results.contains(x_dims)) { + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionBwdDataAlgoPerf_t perf; + int algo_count = 1; + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); + s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); + } + + const auto& perf = s_.cached_benchmark_results.at(x_dims); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } + + // The following block will be executed in case there has been no change in the shapes of the + // input and the filter compared to the previous run + if (!y_data) { + auto y_dims = s_.y_dims.AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // erase the fake H dimension + y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); + } else { + // erase the fake W dimension + y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); + } + } + + Tensor* Y = context->Output(0, TensorShape(y_dims)); + y_data = reinterpret_cast(Y->MutableData()); + + // Bail out early if one of the output dimensions is zero. 
+ if (Y->Shape().Size() == 0) { + return Status::OK(); + } + } + + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, + s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + + if (has_bias) { + const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); + auto b_data = reinterpret_cast(B->Data()); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + } + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime + +#ifdef _WIN32 +#pragma warning(pop) +#endif From cfd7da71efc78a208e194fefd4c66da46eb1a540 Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Fri, 16 Aug 2024 09:55:55 +0200 Subject: [PATCH 2/5] Adding [[maybe_unused]] to logger in ConvTransposeNeedFallbackToCPU and Linting --- onnxruntime/core/providers/cuda/cuda_execution_provider.cc | 3 ++- onnxruntime/core/providers/cuda/nn/conv_transpose.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 16d18ee4d86fb..7669ed5ae0fbc 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2408,7 +2408,8 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, +static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, + [[maybe_unused]] const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { #if CUDNN_MAJOR < 9 diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 3b8f117522210..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -20,7 +20,7 @@ class ConvTranspose : public CudaKernel { public: using CudaT = typename ToCudaType::MappedType; - ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){}; + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; From b5a48161be8b2a534de08f2dabcdbad22ba9d72c Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Mon, 19 Aug 2024 10:52:54 +0200 Subject: [PATCH 3/5] Adding node as maybe_unused, due to MSVC error --- onnxruntime/core/providers/cuda/cuda_execution_provider.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 7669ed5ae0fbc..30cc0addb2d7c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2408,7 +2408,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static 
bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, +static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::Node& node, [[maybe_unused]] const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { From 60a048426461b8131c2079ffbcd70824f3ba70fa Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Mon, 2 Sep 2024 12:50:55 +0200 Subject: [PATCH 4/5] Bringing back conv transpose fallback in case of asymmetric padding --- onnxruntime/core/providers/cuda/cuda_execution_provider.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 30cc0addb2d7c..62a7b4afb7071 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2412,7 +2412,6 @@ static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::N [[maybe_unused]] const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { -#if CUDNN_MAJOR < 9 const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -2459,7 +2458,6 @@ static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::N } } } -#endif #ifdef ENABLE_CUDA_NHWC_OPS if (prefer_nhwc) { From 3e1a2d1a900726b1b461aa2c4e89f4bf383deb99 Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Tue, 10 Sep 2024 13:29:24 +0200 Subject: [PATCH 5/5] ConvTranspose fix: save y_dims for next run --- onnxruntime/core/providers/cuda/nn/conv_transpose.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index b3fee863b5f19..d4876e1714861 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -345,8 +345,9 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna conv_transpose_attrs_.ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, dilations, local_output_padding, N, &pads, &y_dims, channels_last); - TensorShape Yshape(y_dims); - s_.Y = context->Output(0, Yshape); + + s_.y_dims = gsl::make_span(y_dims); + s_.Y = context->Output(0, s_.y_dims); s_.y_data = reinterpret_cast(s_.Y->MutableData()); const CUDAExecutionProvider* cuda_ep =