-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TensorRT EP] Fix bug for shape tensor input (#18253)
When the model has "shape tensor" as one of the inputs and user provides explicit profile shapes for it, TRT EP doesn't correctly set the "shape tensor" input. Also, there is a bug for applying explicit profile shapes for the shape tensor input. Note: It seems the model has shape tensor input is a rare case. Most of the cases, the inputs are all execution tensor.
- Loading branch information
Showing
4 changed files
with
162 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
:� | ||
) | ||
data | ||
shapereshapedReshape"Reshapetrt_engine_wrapperZ | ||
data | ||
N | ||
Z | ||
shape | ||
|
||
b | ||
reshaped | ||
B | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import onnx | ||
from onnx import TensorProto, helper | ||
|
||
|
||
def generate_model(model_name): | ||
nodes = [ | ||
helper.make_node( | ||
"Reshape", | ||
["data", "shape"], | ||
["reshaped"], | ||
"Reshape", | ||
), | ||
] | ||
|
||
graph = helper.make_graph( | ||
nodes, | ||
"trt_engine_wrapper", | ||
[ # input | ||
helper.make_tensor_value_info("data", TensorProto.FLOAT, ["N", 2]), | ||
helper.make_tensor_value_info( | ||
"shape", | ||
TensorProto.INT64, | ||
[ | ||
2, | ||
], | ||
), | ||
], | ||
[ # output | ||
helper.make_tensor_value_info("reshaped", TensorProto.FLOAT, [4, 1]), | ||
], | ||
) | ||
|
||
model = helper.make_model(graph) | ||
onnx.save(model, model_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
generate_model("trt_reshape.onnx") |