Improve CUDA EP's GetCapability (#17809)
Improve CUDA EP's GetCapability: add layout transformer support.

Currently the code checks whether a node is already assigned to an EP; if it is, it returns immediately:
```c++
    if (!node.GetExecutionProviderType().empty()) {
      return;
    }
```

So, if you call GetCapability twice:

```c++
auto caps = GetCapability();
assign_nodes_to_eps(..., caps, ...);  // nodes in caps are now assigned to this EP
auto caps2 = GetCapability();         // with the old check, caps2 omits those nodes
```
The second GetCapability() call will return fewer results than the first. The layout transformer needs to call GetCapability twice in exactly this way, so the current implementation is incompatible with it. This is not an issue today because the CUDA EP does not perform layout transformation, but we may want to support a different layout in the future.
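
To address this, the change keeps nodes that are already assigned to the CUDA EP in the returned capabilities instead of skipping them. A minimal sketch of the new check inside GetCapability's node loop (full context in the diff below):

```c++
// Sketch of the new behavior: nodes that an earlier partitioning pass already
// assigned to the CUDA EP stay in the result instead of being dropped.
if (!node.GetExecutionProviderType().empty()) {
  if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
    candidates.push_back(node.Index());  // still reported on a second call
  }
  continue;  // skip the kernel lookup and CPU-fallback checks for assigned nodes
}
```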
snnn authored Oct 7, 2023
1 parent 37f4f27 commit b76994d
Showing 2 changed files with 18 additions and 8 deletions.
23 changes: 15 additions & 8 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -373,7 +373,7 @@ Status CUDAExecutionProvider::OnRunStart() {
   // always set CUDA device when session::Run() in case it runs in a worker thread
   CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
   if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
-    LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+    LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
     GetPerThreadContext().CaptureBegin();
   }
   return Status::OK();
@@ -2410,7 +2410,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
   return false;
 }
 
-static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
+static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) {
   const auto& node_attributes = node.GetAttributes();
   // Check attributes
   for (auto& attr : node_attributes) {
@@ -2428,7 +2428,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
       int rank = pads_size / 2;
       for (int i = 0; i < rank; i++) {
         if (pads.Get(i) != pads.Get(i + rank)) {
-          LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name()
+          LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name()
                                 << " to CPU because it requires asymmetric padding which the CUDA EP"
                                 << " currently does not support";
           return true;
@@ -2450,7 +2450,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
       // symmetric padding.
       // TODO: Remove this after we have supported asymmetric padding in the CUDA ConvTranspose kernel
       if (auto_pad_attr == "SAME_UPPER" || auto_pad_attr == "SAME_LOWER") {
-        LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name()
+        LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name()
                               << " to CPU because it uses the auto_pad attribute which may lead to asymmetric padding which"
                               << " the CUDA EP currently does not support";
         return true;
@@ -2487,20 +2487,26 @@ std::vector<std::unique_ptr<ComputeCapability>>
 CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
                                      const IKernelLookup& kernel_lookup) const {
   InlinedVector<NodeIndex> candidates;
+  // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU.
+  InlinedVector<NodeIndex> tentative_nodes;
+  const logging::Logger& logger = *GetLogger();
   for (auto& node_index : graph.GetNodesInTopologicalOrder()) {
     const auto* p_node = graph.GetNode(node_index);
     if (p_node == nullptr)
       continue;
 
     const auto& node = *p_node;
     if (!node.GetExecutionProviderType().empty()) {
+      if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
+        candidates.push_back(node.Index());
+      }
       continue;
     }
 
     const KernelCreateInfo* cuda_kernel_def = kernel_lookup.LookUpKernel(node);
     // none of the provided registries has a CUDA kernel for this node
     if (cuda_kernel_def == nullptr) {
-      LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
+      LOGS(logger, INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
       continue;
     }
 
@@ -2520,7 +2526,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
       not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType());
       force_inside = !not_supported;
     } else if ("ConvTranspose" == node.OpType()) {
-      not_supported = ConvTransposeNeedFallbackToCPU(node);
+      not_supported = ConvTransposeNeedFallbackToCPU(node, logger);
       force_inside = !not_supported;
     } else if ("Cast" == node.OpType()) {
       not_supported = CastNeedFallbackToCPU(node);
@@ -2529,17 +2535,18 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
 
     if (!force_inside && not_supported) {
       if (not_supported) {
-        LOGS_DEFAULT(WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name();
+        LOGS(logger, WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name();
       }
     } else {
+      tentative_nodes.push_back(node.Index());
       candidates.push_back(node.Index());
     }
   }
 
   // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
   // These are usually shape related computation subgraphs
   // Following logic can be extended for other EPs
-  auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates);
+  auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
   std::vector<std::unique_ptr<ComputeCapability>> result;
   for (auto& node_index : candidates) {
     if (cpu_nodes.count(node_index) > 0)
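
In short: `candidates` now also keeps nodes that an earlier partitioning pass already assigned to the CUDA EP, while `tentative_nodes` holds only the newly claimed nodes, and only those are offered to `GetCpuPreferredNodes` for possible CPU fallback. Because the final result is built from `candidates`, a repeated GetCapability call keeps reporting the nodes the CUDA EP already owns, which is what the layout transformer relies on.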
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -350,6 +350,9 @@ void InitProviderOrtApi();
   if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \
     CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM)->Stream()
 
+#define LOGS(logger, severity) \
+  LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
+
 #define LOGS_DEFAULT_CATEGORY(severity, category) \
   LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)
 
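
As an illustration (not part of the commit), the new LOGS macro lets shared-provider code log through a specific logger instance instead of always going through the process-wide default logger; `GetLogger()` and `node` are assumed to be in scope, as in the CUDA EP code above:

```c++
// Illustrative sketch: logging through an instance logger with the new macro.
const logging::Logger& logger = *GetLogger();  // the EP's session logger
LOGS(logger, WARNING) << "Dropping node " << node.Name() << " to CPU";

// The pre-existing macro always routes to the default logger:
LOGS_DEFAULT(WARNING) << "Dropping node " << node.Name() << " to CPU";
```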