Skip to content

Commit

Permalink
add python function def test_shape_reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Mar 22, 2024
1 parent 45bae60 commit 5c083eb
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,31 @@ def run_model_with_cuda_graph(self, providers):
atol=1e-05,
)

def test_shape_reshape(self):
providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})]
x = np.random.rand(1,3,60,10).astype(np.float32)
y = np.random.rand(1,3,60,10).astype(np.float32)
x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0)
y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0)

onnxrt.set_default_logger_severity(0)
session = onnxrt.InferenceSession("/bert_ort/leca/models/prod_model11/prod_model11.onnx", providers=providers)
io_binding = session.io_binding()

# Bind the input and output
io_binding.bind_ortvalue_input("data", x_ortvalue)
io_binding.bind_ortvalue_output("score", y_ortvalue)

# One regular run for the necessary memory allocation and cuda graph capturing
session.run_with_iobinding(io_binding)

# After capturing, CUDA graph replay happens from this Run onwards
session.run_with_iobinding(io_binding)

# Update input and then replay CUDA graph
x_ortvalue.update_inplace(np.random.rand(1,3,60,10).astype(np.float32))
session.run_with_iobinding(io_binding)

def run_model_with_cuda_graph_annotation(self, providers):
INPUT_SIZE = 1280 # noqa: N806

Expand Down

0 comments on commit 5c083eb

Please sign in to comment.