Skip to content

Commit

Permalink
refactor unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Nov 10, 2023
1 parent 11ce3fc commit f9206f3
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions onnxruntime/test/python/onnxruntime_test_engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@ class TestInferenceSessionWithCtxNode(unittest.TestCase):
trt_engine_cache_path_ = "./trt_engine_cache"
ctx_node_model_name_ = "ctx_node.onnx"

def test_ctx_node(self):
    """Exercise the EPContext-node path on the TensorRT EP with engine caching on.

    No-ops (implicitly passes) when the TensorrtExecutionProvider is not
    available in this build.
    """
    # Guard clause: nothing to test without the TRT EP.
    if "TensorrtExecutionProvider" not in onnxrt.get_available_providers():
        return
    trt_options = {
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": self.trt_engine_cache_path_,
    }
    self.run_model([("TensorrtExecutionProvider", trt_options)])
# This test is only for TRT EP to test EPContext node with TRT engine
@unittest.skipIf(
"TensorrtExecutionProvider" not in onnxrt.get_available_providers(),

Check failure

Code scanning / lintrunner

RUFF/F821 Error test

Undefined name ort.
See https://beta.ruff.rs/docs/rules/
reason="Test TRT EP only",
)

def create_ctx_node(self, ctx_embed_mode=0, cache_path=""):
if ctx_embed_mode:
Expand Down Expand Up @@ -60,24 +56,30 @@ def create_ctx_node(self, ctx_embed_mode=0, cache_path=""):
model = helper.make_model(graph)
onnx.save(model, self.ctx_node_model_name_)

def run_model(self, providers):
def test_ctx_node(self):
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

# First session and run to create engine cache
providers = [
(
"TensorrtExecutionProvider",
{"trt_engine_cache_enable": True, "trt_engine_cache_path": self.trt_engine_cache_path_},
)
]
session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers)

# One regular run to create engine cache
session.run(
["Y"],
{"X": x},
)

# Get engine cache name
cache_name = ""
for f in os.listdir(self.trt_engine_cache_path_):
if f.endswith(".engine"):
cache_name = f
print(cache_name)

# Second run to test ctx node with engine cache path
# Second session and run to test ctx node with engine cache path
self.create_ctx_node(cache_path=os.path.join(self.trt_engine_cache_path_, cache_name))
providers = [("TensorrtExecutionProvider", {})]
session = onnxrt.InferenceSession(get_name(self.ctx_node_model_name_), providers=providers)
Expand All @@ -86,7 +88,7 @@ def run_model(self, providers):
{"X": x},
)

# Third run to test ctx node with engine binary content
# Third session and run to test ctx node with engine binary content
self.create_ctx_node(ctx_embed_mode=1, cache_path=os.path.join(self.trt_engine_cache_path_, cache_name))
session = onnxrt.InferenceSession(get_name(self.ctx_node_model_name_), providers=providers)
session.run(
Expand Down

0 comments on commit f9206f3

Please sign in to comment.