default save_fp32_intermediate_model=False
tianleiwu committed Nov 4, 2023
1 parent 42f3a67 commit eff8636
Showing 4 changed files with 35 additions and 29 deletions.
@@ -242,7 +242,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         skip_group_norm.domain = "com.microsoft"

         self.increase_counter(
-            f"SkipGroupNorm(add_out={len(outputs) > 1} bias={bias is not None} broadcast={broadcast})"
+            f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
         )

         # Pass attributes from GroupNorm node to SkipGroupNorm
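The only code change in this hunk casts the booleans to int, so the fused-operator counter label reads add_out=1 rather than add_out=True. A minimal standalone illustration (the variable values below are hypothetical):

    # Hypothetical values standing in for the fusion's locals.
    outputs = ["group_norm_out", "add_out"]  # two outputs => add_out is present
    bias = None                              # no bias input
    broadcast = True

    # Before: {len(outputs) > 1} renders as "True"/"False"; after, int() renders 1/0.
    label = f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
    print(label)  # SkipGroupNorm(add_out=1 bias=0 broadcast=1)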
@@ -164,6 +164,7 @@ def build_engines(
         opt_batch_size: int = 1,
         force_engine_rebuild: bool = False,
         device_id: int = 0,
+        save_fp32_intermediate_model=False,
     ):
         self.torch_device = torch.device("cuda", device_id)
         self.load_models(framework_model_dir)
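A hedged sketch of a call site for the new keyword (the builder object, the framework_model_dir value, and the remaining argument values are assumptions; only save_fp32_intermediate_model comes from this commit):

    # Opt in to the fp32 intermediate model while tuning fp16 accuracy.
    # With the new default (False), only the final optimized model is written.
    builder.build_engines(
        framework_model_dir="./framework_models",  # assumed location of PyTorch models
        opt_batch_size=1,
        force_engine_rebuild=False,
        device_id=0,
        save_fp32_intermediate_model=True,
    )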
@@ -230,35 +231,38 @@ def build_engines(
             # Generate fp32 optimized model.
             # If the final target is a fp16 model, we save the fp32 optimized model so that it is
             # easy to tune fp16 conversion. That can save a lot of development time.
-            if not os.path.exists(onnx_fp32_path):
-                print("------")
-                logger.info("Generating optimized model: %s", onnx_fp32_path)
-
-                # There is a risk that some ORT fused ops are fp32 only. So far, we have not encountered such an issue.
-                model_obj.optimize_ort(
-                    onnx_path,
-                    onnx_fp32_path,
-                    to_fp16=False,
-                    fp32_op_list=self.model_config[model_name].force_fp32_ops,
-                    optimize_by_ort=self.model_config[model_name].optimize_by_ort,
-                )
-            else:
-                logger.info("Found cached optimized model: %s", onnx_fp32_path)
+            use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
+            if use_fp32_intermediate:
+                if not os.path.exists(onnx_fp32_path):
+                    print("------")
+                    logger.info("Generating optimized model: %s", onnx_fp32_path)
+
+                    # There is a risk that some ORT fused ops are fp32 only. So far, we have not encountered such an issue.
+                    model_obj.optimize_ort(
+                        onnx_path,
+                        onnx_fp32_path,
+                        to_fp16=False,
+                        fp32_op_list=self.model_config[model_name].force_fp32_ops,
+                        optimize_by_ort=self.model_config[model_name].optimize_by_ort,
+                    )
+                else:
+                    logger.info("Found cached optimized model: %s", onnx_fp32_path)

-            # Generate fp16 optimized model.
+            # Generate the final optimized model.
             if not os.path.exists(onnx_opt_path):
                 print("------")
                 logger.info("Generating optimized model: %s", onnx_opt_path)

-                # The input is the fp32 optimized model, so we need not run fusion again in this step.
+                # This step converts the model to fp16, then runs ORT optimization to fuse fp16 ops if possible.
+                # When there is a fp32 intermediate optimized model, this step just converts it from fp32 to fp16.
+                optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort
+
                 model_obj.optimize_ort(
-                    onnx_fp32_path,
+                    onnx_fp32_path if use_fp32_intermediate else onnx_path,
                     onnx_opt_path,
                     to_fp16=self.model_config[model_name].fp16,
                     fp32_op_list=self.model_config[model_name].force_fp32_ops,
-                    optimize_by_ort=self.model_config[model_name].optimize_by_ort,
-                    optimize_by_fusion=False,
+                    optimize_by_ort=optimize_by_ort,
+                    optimize_by_fusion=not use_fp32_intermediate,
                 )
             else:
                 logger.info("Found cached optimized model: %s", onnx_opt_path)
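To summarize the gating above: all three changed arguments of the final optimize_ort call derive from use_fp32_intermediate. A minimal sketch (the helper function and dict return are my framing; the expressions come from the diff):

    def final_step_args(use_fp32_intermediate: bool, config_optimize_by_ort: bool,
                        onnx_path: str, onnx_fp32_path: str) -> dict:
        """Mirror how the final optimize_ort call is parameterized."""
        return {
            # The fp32 intermediate model is already fused and ORT-optimized,
            # so the final step starts from it and only needs fp16 conversion.
            "input_path": onnx_fp32_path if use_fp32_intermediate else onnx_path,
            "optimize_by_ort": False if use_fp32_intermediate else config_optimize_by_ort,
            "optimize_by_fusion": not use_fp32_intermediate,
        }

    # Without the intermediate model, the final step runs the full pipeline:
    print(final_step_args(False, True, "model.onnx", "model_fp32.onnx"))
    # {'input_path': 'model.onnx', 'optimize_by_ort': True, 'optimize_by_fusion': True}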
@@ -91,13 +91,6 @@ def optimize(
         if keep_outputs:
             m.prune_graph(outputs=keep_outputs)

-        if float16:
-            logger.info("Convert to float16 ...")
-            m.convert_float_to_float16(
-                keep_io_types=keep_io_types,
-                op_block_list=fp32_op_list,
-            )
-
         use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

         # Note that ORT < 1.16 could not save a model larger than 2GB.
@@ -110,6 +103,13 @@
         if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
             m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format)

+        if float16:
+            logger.info("Convert to float16 ...")
+            m.convert_float_to_float16(
+                keep_io_types=keep_io_types,
+                op_block_list=fp32_op_list,
+            )
+
         m.get_operator_statistics()
         m.get_fused_operator_statistics()
         m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
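The net effect of these two hunks is to move fp16 conversion after ORT optimization, so ORT optimizes a fp32 graph and conversion happens last. A hedged sketch of the resulting order (the model paths and op_block_list contents are illustrative, not from the commit):

    import onnx
    from onnxruntime.transformers.onnx_model import OnnxModel

    m = OnnxModel(onnx.load("unet.onnx"))  # illustrative input path
    # 1. Fusion and (optionally) ORT optimization run here, still in fp32.
    # 2. Only then is the graph converted to fp16, keeping chosen ops in fp32:
    m.convert_float_to_float16(
        keep_io_types=True,                   # keep graph inputs/outputs fp32
        op_block_list=["RandomNormalLike"],   # illustrative fp32-only ops
    )
    m.save_model_to_file("unet_fp16.onnx")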
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/transformers/onnx_model.py
@@ -1126,7 +1126,9 @@ def get_operator_statistics(self, include_domain=False):
             op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type
             op_count[op] = 1 if op not in op_count else (op_count[op] + 1)

-        logger.info(f"Operators:{op_count}")
+        # Sorted by count in descending order, then by op type in alphabetical order.
+        logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv: (-kv[1], kv[0]))}")

         return op_count

     @staticmethod
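A quick illustration of the new sort key (hypothetical counts):

    op_count = {"MatMul": 10, "Add": 10, "Softmax": 3}
    # Descending by count, ties broken alphabetically by op type:
    print(sorted(op_count.items(), key=lambda kv: (-kv[1], kv[0])))
    # [('Add', 10), ('MatMul', 10), ('Softmax', 3)]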