
[Performance] How can I forcefully assign nodes to CUDA EP? #17930

Closed
dhung-msft opened this issue Oct 13, 2023 · 5 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments

@dhung-msft

dhung-msft commented Oct 13, 2023

Describe the issue

Similar to #16863, I have a model in which some nodes are assigned to the CPU EP, with the following warning:

[W:onnxruntime:, session_state.cc:1030 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.

The nodes placed on CPU EP are Shape and nodes that consume from the Shape output (such as Gather, Unsqueeze, Concat in my model). I want to run the model using CUDA Graphs, so I need all nodes to be placed on the CUDA EP. Is there a way I can force assignment of these nodes onto CUDA EP?

To reproduce

import onnx
from onnx import TensorProto
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info, make_opsetid
from onnx.checker import check_model
import numpy as np
import onnxruntime as ort

def create_model() -> onnx.ModelProto:
    # Create a simple graph that includes Shape op
    data = make_tensor_value_info('data', TensorProto.FLOAT, [None, None])
    indices = make_tensor_value_info('indices', TensorProto.INT32, [None])
    output = make_tensor_value_info('output', TensorProto.INT64, [None])

    # Create nodes
    node0 = make_node('Shape', ['data'], ['shape'])
    node1 = make_node('Gather', ['shape', 'indices'], ['output'])
    
    # from nodes to graph
    graph = make_graph([node0, node1],  # nodes
                        'graph',  # a name
                        [data, indices],  # inputs
                        [output])  # outputs
    onnx_model = make_model(graph, opset_imports=[make_opsetid('', 17)])
    check_model(onnx_model)
    return onnx_model

def run_model(model_proto: bytes):
    providers = [("CUDAExecutionProvider", {'enable_cuda_graph': True})]
    x = np.zeros((2,5), dtype=np.float32)
    idx = np.array([0], dtype=np.int32)
    y = np.zeros_like(idx, dtype=np.int64)  # model output is INT64 (Gather on a Shape output)
    x_ortvalue = ort.OrtValue.ortvalue_from_numpy(x, 'cuda', 0)
    idx_ortvalue = ort.OrtValue.ortvalue_from_numpy(idx, 'cuda', 0)
    y_ortvalue = ort.OrtValue.ortvalue_from_numpy(y, 'cuda', 0)
    
    session = ort.InferenceSession(model_proto, providers=providers)
    io_binding = session.io_binding()
    
    # Bind the input and output
    io_binding.bind_ortvalue_input('data', x_ortvalue)
    io_binding.bind_ortvalue_input('indices', idx_ortvalue)
    io_binding.bind_ortvalue_output('output', y_ortvalue)
    
    # Run
    session.run_with_iobinding(io_binding)
    
if __name__ == '__main__':
    model = create_model()
    run_model(model.SerializeToString())

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.13.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No

@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label Oct 13, 2023
@tianleiwu
Contributor

You can try ORT 1.16.1. Running some nodes on CPU does not always prevent CUDA Graph capture in ORT 1.16.1. However, if a tensor needs to be copied from host to device, or device to host, that will prevent CUDA Graph.

If that does not work, try session.disable_fallback() to see whether it helps.

If the above does not work, you will need to use an offline script (like this) to remove shape-computation nodes and replace them with initializers.

@skottmckay
Contributor

I need all nodes to be placed on the CUDA EP

At the cost of performance?

Whilst the CUDA EP has a Shape operator, it uses the CPU EP's implementation, as the shape information is in CPU-allocated memory, not CUDA memory.

#include "core/providers/cpu/tensor/shape_op.h"

The output of the CUDA Shape node is CPU based memory.

.OutputMemoryType(OrtMemTypeCPUInput, 0)

What's generally happening is that after a Shape node some manipulations happen to this CPU based data (Gather, Slice, Unsqueeze, Concat type things) and it's less efficient to attempt to do that on CUDA.

// The algo below is trying to identify a subgraph that only depends on cpu tensors.
// Usually it is a subgraph that does shape calculation based on a GPU tensor, then reshapes it back.
// The detail:
// for each candidate, if one of its inputs is a cpu tensor and the Non-CPU kernel doesn't mark it as cpu input,
// force the node to CPU to avoid a memory copy, and add its output to the small cpu tensors.

If you set log level to INFO you should see messages from here:

LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
<< " capable of executing this node";

@dhung-msft
Author

I need all nodes to be placed on the CUDA EP

At the cost of performance?

Yes, just to compare performance as-is vs. using CUDA Graphs. However, if even the CUDA Shape node outputs to CPU memory, then it sounds like it still won't work with CUDA Graphs. I'll try @tianleiwu's suggestions and report back.

@tianleiwu
Contributor

In ORT 1.16.1, when the shape computation (on CPU) ends with a Reshape node (that's the common use case), it can still work with CUDA Graphs.

@dhung-msft
Author

The model runs fine with CUDA Graphs after upgrading to ORT 1.16.1. Thanks!
