class TensorRTEngineWrapperCreator:
    """Wrap a serialized TensorRT engine in a one-node ONNX model.

    The generated model contains a single ``com.microsoft::EPContext`` node
    whose inputs/outputs mirror the engine's I/O tensors, so ONNX Runtime's
    TensorRT EP can load the engine directly instead of rebuilding it.
    """

    def __init__(self, args):
        """Deserialize the engine and build the wrapper graph in memory.

        Args:
            args: parsed CLI namespace providing ``embed_mode`` (embed engine
                bytes vs. store only the cache path), ``trt_engine_cache_path``
                (path to the serialized engine) and ``model_name`` (output
                ONNX file name).
        """
        ctx_embed_mode = args.embed_mode
        engine_cache_path = args.trt_engine_cache_path
        self.model_name = args.model_name
        # Counter used to give each dynamic ("-1") dimension a unique
        # symbolic name ("free_dim_<k>") across all inputs and outputs.
        self.dynamic_dim_count = 0

        # Get serialized engine from engine cache
        with open(engine_cache_path, "rb") as file:
            engine_buffer = file.read()

        # embed_mode != 0: the engine bytes themselves become the node
        # attribute; embed_mode == 0: only the cache path is stored.
        if ctx_embed_mode:
            ep_cache_context_content = engine_buffer
        else:
            ep_cache_context_content = engine_cache_path

        # Deserialize a TRT engine
        logger = trt.Logger(trt.Logger.WARNING)
        runtime = trt.Runtime(logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)
        # NOTE(review): num_bindings is the legacy binding-count API while the
        # lookups below use the newer name-based tensor API; newer TensorRT
        # releases expose num_io_tensors instead — confirm against the TRT
        # version this script targets.
        num_bindings = engine.num_bindings

        input_tensors = []
        output_tensors = []
        input_tensor_shapes = []
        output_tensor_shapes = []
        input_tensor_types = []
        output_tensor_types = []

        # Get type and shape of each input/output
        for b_index in range(num_bindings):
            tensor_name = engine.get_tensor_name(b_index)
            tensor_shape = engine.get_tensor_shape(tensor_name)
            tensor_type = engine.get_tensor_dtype(tensor_name)
            if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                input_tensors.append(tensor_name)
                input_tensor_shapes.append(tensor_shape)
                input_tensor_types.append(tensor_type)
            else:
                output_tensors.append(tensor_name)
                output_tensor_shapes.append(tensor_shape)
                output_tensor_types.append(tensor_type)

        # Note:
        # The TRT engine should be built with min, max and opt profiles so that
        # dynamic shape input can have dimension of "-1"
        print(input_tensors)
        print(input_tensor_types)
        print(input_tensor_shapes)
        print(output_tensors)
        print(output_tensor_types)
        print(output_tensor_shapes)

        nodes = [
            helper.make_node(
                "EPContext",
                input_tensors,
                output_tensors,
                "EPContext",
                domain="com.microsoft",
                embed_mode=ctx_embed_mode,
                ep_cache_context=ep_cache_context_content,
            ),
        ]

        # Build value_infos for the graph I/O from the parallel name/type/shape
        # lists collected above.
        model_inputs = [
            helper.make_tensor_value_info(
                name,
                self.trt_data_type_to_onnx_data_type(dtype),
                self.trt_shape_to_ort_shape(shape),
            )
            for name, dtype, shape in zip(input_tensors, input_tensor_types, input_tensor_shapes)
        ]

        model_outputs = [
            helper.make_tensor_value_info(
                name,
                self.trt_data_type_to_onnx_data_type(dtype),
                self.trt_shape_to_ort_shape(shape),
            )
            for name, dtype, shape in zip(output_tensors, output_tensor_types, output_tensor_shapes)
        ]

        self.graph = helper.make_graph(
            nodes,
            "trt_engine_wrapper",
            model_inputs,
            model_outputs,
        )

    def trt_data_type_to_onnx_data_type(self, trt_data_type):
        """Map a ``trt.DataType`` to the matching ``TensorProto`` element type.

        Returns ``TensorProto.UNDEFINED`` for any type without a mapping.
        """
        # Lookup table instead of an if/elif ladder; built lazily so that
        # merely importing/defining the class does not require TensorRT.
        trt_to_onnx = {
            trt.DataType.FLOAT: TensorProto.FLOAT,
            trt.DataType.HALF: TensorProto.FLOAT16,
            trt.DataType.INT8: TensorProto.INT8,
            trt.DataType.INT32: TensorProto.INT32,
            trt.DataType.BOOL: TensorProto.BOOL,
            trt.DataType.UINT8: TensorProto.UINT8,
        }
        return trt_to_onnx.get(trt_data_type, TensorProto.UNDEFINED)

    # TRT uses "-1" to represent dynamic dimension
    # ORT uses symbolic name to represent dynamic dimension
    # Here we only do the conversion when there is any dynamic dimension in the shape
    def trt_shape_to_ort_shape(self, trt_data_shape):
        """Convert a TRT shape to an ORT-compatible shape.

        Every ``-1`` (dynamic) dimension is replaced with a unique symbolic
        name ``free_dim_<k>``; fully static shapes are returned unchanged.
        """
        # Fully static shape: nothing to convert.  (The original code checked
        # has_dynamic_dim twice; a single check is sufficient.)
        if all(dim != -1 for dim in trt_data_shape):
            return trt_data_shape

        ort_data_shape = []
        for dim in trt_data_shape:
            if dim == -1:
                ort_data_shape.append("free_dim_" + str(self.dynamic_dim_count))
                self.dynamic_dim_count += 1
            else:
                ort_data_shape.append(dim)
        return ort_data_shape

    def create_model(self):
        """Serialize the wrapper graph to ``self.model_name`` on disk."""
        model = helper.make_model(self.graph)
        onnx.save(model, self.model_name)
        print(self.model_name + " is created.")
= args.trt_engine_cache_path - - # Get serialized engine from engine cache - with open(engine_cache_path, "rb") as file: - engine_buffer = file.read() - - if ctx_embed_mode: - ep_cache_context_content = engine_buffer - else: - ep_cache_context_content = engine_cache_path - - # Deserialize an TRT engine - logger = trt.Logger(trt.Logger.WARNING) - runtime = trt.Runtime(logger) - engine = runtime.deserialize_cuda_engine(engine_buffer) - num_bindings = engine.num_bindings - - input_tensors = [] - output_tensors = [] - input_tensor_shapes = [] - output_tensor_shapes = [] - input_tensor_types = [] - output_tensor_types = [] - - # Get type and shape of each input/output - for b_index in range(num_bindings): - tensor_name = engine.get_tensor_name(b_index) - tensor_shape = engine.get_tensor_shape(tensor_name) - tensor_type = engine.get_tensor_dtype(tensor_name) - if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: - input_tensors.append(tensor_name) - input_tensor_shapes.append(tensor_shape) - input_tensor_types.append(tensor_type) - else: - output_tensors.append(tensor_name) - output_tensor_shapes.append(tensor_shape) - output_tensor_types.append(tensor_type) - - # Note: - # The TRT engine should be built with min, max and opt profiles so that dynamic shape input can have dimension of "-1" - print(input_tensors) - print(input_tensor_types) - print(input_tensor_shapes) - print(output_tensors) - print(output_tensor_types) - print(output_tensor_shapes) - - nodes = [ - helper.make_node( - "EPContext", - input_tensors, - output_tensors, - "EPContext", - domain="com.microsoft", - embed_mode=ctx_embed_mode, - ep_cache_context=ep_cache_context_content, - ), - ] - - model_inputs = [] - for i in range(len(input_tensors)): - model_inputs.append( - helper.make_tensor_value_info( - input_tensors[i], - trt_data_type_to_onnx_data_type(input_tensor_types[i]), - trt_shape_to_ort_shape(input_tensor_shapes[i]), - ) - ) - - model_outputs = [] - for i in 
range(len(output_tensors)): - model_outputs.append( - helper.make_tensor_value_info( - output_tensors[i], - trt_data_type_to_onnx_data_type(output_tensor_types[i]), - trt_shape_to_ort_shape(output_tensor_shapes[i]), - ) - ) - - graph = helper.make_graph( - nodes, - "trt_engine_wrapper", - model_inputs, - model_outputs, - ) - - model = helper.make_model(graph) - onnx.save(model, args.model_name) + ctor = TensorRTEngineWrapperCreator(args) + ctor.create_model() if __name__ == "__main__":