From b76994dc3a98715f624a9263bb0ff123b03b1e25 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Sat, 7 Oct 2023 09:05:02 -0700
Subject: [PATCH] Improve CUDA EP's GetCapability (#17809)

Improve the CUDA EP's GetCapability: add layout transformer support.
Currently the code checks whether a node is already assigned to an EP
and, if so, skips the node:

```c++
if (!node.GetExecutionProviderType().empty()) {
  continue;
}
```

So if you call GetCapability twice,

```c++
auto caps = GetCapability();
assign_nodes_to_eps(..., caps, ...);
auto caps2 = GetCapability();
```

the second GetCapability() call returns fewer results than the first,
because every node assigned in between is filtered out. The layout
transformer needs to call GetCapability twice, exactly as above, so the
current implementation is incompatible with it. This is not an issue
right now because the CUDA EP does not yet do a layout transform, but
we might want to support a different layout in the future.
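To make the intended behavior concrete, below is a minimal standalone
sketch of the idempotent pattern this patch adopts. The types and names
here (SimpleNode, kThisEp, GetCapabilitySketch) are hypothetical
simplifications, not the real ONNX Runtime API: nodes this EP already
owns stay in the candidate list on a repeated call, while only newly
claimed nodes are offered to the CPU-fallback heuristic.

```c++
// Hypothetical sketch only: SimpleNode, kThisEp and GetCapabilitySketch
// are simplified stand-ins for the real onnxruntime types.
#include <cstddef>
#include <string>
#include <vector>

struct SimpleNode {
  std::size_t index;
  std::string assigned_ep;  // empty until some EP claims this node
};

constexpr const char* kThisEp = "CUDAExecutionProvider";

std::vector<std::size_t> GetCapabilitySketch(const std::vector<SimpleNode>& nodes) {
  std::vector<std::size_t> candidates;       // every node this EP claims
  std::vector<std::size_t> tentative_nodes;  // subset still eligible for CPU fallback
  for (const auto& node : nodes) {
    if (!node.assigned_ep.empty()) {
      // Old behavior: skip every assigned node, so a second call shrinks.
      // New behavior: keep the nodes this EP already owns in the result.
      if (node.assigned_ep == kThisEp) {
        candidates.push_back(node.index);
      }
      continue;
    }
    candidates.push_back(node.index);
    tentative_nodes.push_back(node.index);
  }
  // Only tentative_nodes would be handed to a CPU-preference pass such as
  // GetCpuPreferredNodes(); already-assigned nodes must not move to CPU again.
  return candidates;
}
```

With this pattern, calling the function twice over the same graph yields
the same candidate set, which is the property the layout transformer
relies on.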
---
 .../providers/cuda/cuda_execution_provider.cc | 23 ++++++++++++-------
 .../providers/shared_library/provider_api.h   |  3 +++
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index ad892eab3b843..de01e240a06c7 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -373,7 +373,7 @@ Status CUDAExecutionProvider::OnRunStart() {
   // always set CUDA device when session::Run() in case it runs in a worker thread
   CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
   if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
-    LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+    LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
     GetPerThreadContext().CaptureBegin();
   }
   return Status::OK();
@@ -2410,7 +2410,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
   return false;
 }
 
-static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
+static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) {
   const auto& node_attributes = node.GetAttributes();
   // Check attributes
   for (auto& attr : node_attributes) {
@@ -2428,7 +2428,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
       int rank = pads_size / 2;
       for (int i = 0; i < rank; i++) {
         if (pads.Get(i) != pads.Get(i + rank)) {
-          LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name()
+          LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name()
                                 << " to CPU because it requires asymmetric padding which the CUDA EP"
                                 << " currently does not support";
           return true;
@@ -2450,7 +2450,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
       // symmetric padding.
       // TODO: Remove this after we have supported asymmetric padding in the CUDA ConvTranspose kernel
       if (auto_pad_attr == "SAME_UPPER" || auto_pad_attr == "SAME_LOWER") {
-        LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name()
+        LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name()
                               << " to CPU because it uses the auto_pad attribute which may lead to asymmetric padding which"
                               << " the CUDA EP currently does not support";
         return true;
@@ -2487,6 +2487,9 @@ std::vector<std::unique_ptr<ComputeCapability>>
 CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
                                      const IKernelLookup& kernel_lookup) const {
   InlinedVector<NodeIndex> candidates;
+  // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU.
+  InlinedVector<NodeIndex> tentative_nodes;
+  const logging::Logger& logger = *GetLogger();
   for (auto& node_index : graph.GetNodesInTopologicalOrder()) {
     const auto* p_node = graph.GetNode(node_index);
     if (p_node == nullptr)
@@ -2494,13 +2497,16 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
 
     const auto& node = *p_node;
     if (!node.GetExecutionProviderType().empty()) {
+      if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
+        candidates.push_back(node.Index());
+      }
       continue;
     }
 
     const KernelCreateInfo* cuda_kernel_def = kernel_lookup.LookUpKernel(node);
     // none of the provided registries has a CUDA kernel for this node
     if (cuda_kernel_def == nullptr) {
-      LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
+      LOGS(logger, INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
       continue;
     }
 
@@ -2520,7 +2526,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
       not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType());
       force_inside = !not_supported;
     } else if ("ConvTranspose" == node.OpType()) {
-      not_supported = ConvTransposeNeedFallbackToCPU(node);
+      not_supported = ConvTransposeNeedFallbackToCPU(node, logger);
       force_inside = !not_supported;
     } else if ("Cast" == node.OpType()) {
       not_supported = CastNeedFallbackToCPU(node);
@@ -2529,9 +2535,10 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
 
     if (!force_inside && not_supported) {
       if (not_supported) {
-        LOGS_DEFAULT(WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name();
+        LOGS(logger, WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name();
       }
     } else {
+      tentative_nodes.push_back(node.Index());
       candidates.push_back(node.Index());
     }
   }
@@ -2539,7 +2546,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
   // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
   // These are usually shape related computation subgraphs
   // Following logic can be extended for other EPs
-  auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates);
+  auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
   std::vector<std::unique_ptr<ComputeCapability>> result;
   for (auto& node_index : candidates) {
     if (cpu_nodes.count(node_index) > 0)
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index 0d7da46142170..85599fab808b3 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -350,6 +350,9 @@ void InitProviderOrtApi();
   if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \
     CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM)->Stream()
 
+#define LOGS(logger, severity) \
+  LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
+
 #define LOGS_DEFAULT_CATEGORY(severity, category) \
   LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)