Enable optimize_by_onnxruntime call for float32 unet model (#17483)
This makes it possible to call `optimize_by_onnxruntime` for the float32 unet model when `--use_external_data_format` is also specified.
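For background, `--use_external_data_format` works around protobuf's 2GB single-file limit by storing tensor data in a side file next to the `.onnx` model. A minimal sketch of that mechanism using the stock `onnx` API (the paths are illustrative, not from this repository):

```python
import onnx

# Hypothetical paths; any model whose weights push the file past the 2GB
# protobuf limit (such as a float32 unet) has to be saved this way.
model = onnx.load("unet/model.onnx")
onnx.save_model(
    model,
    "unet_fp32/model.onnx",
    save_as_external_data=True,    # write large initializers to a side file
    all_tensors_to_one_file=True,  # keep them together in one data file
    location="model.onnx.data",    # side file path, relative to the model
)

# load_external_data=True (the default) pulls the tensors back in.
reloaded = onnx.load("unet_fp32/model.onnx", load_external_data=True)
```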

### Motivation and Context
When `optimize_pipeline.py` was run without `--float16`, `optimize_by_onnxruntime` was skipped for the unet model, since ONNX Runtime could not save the resulting float32 model (larger than 2GB) to a single file.
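With this change, an invocation along the lines of `python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-optimized --use_external_data_format` (the `-i`/`-o` flags are an assumption about the script's interface; only `--float16` and `--use_external_data_format` appear in this commit) should run the ONNX Runtime optimization step for the float32 unet as well.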
kazssym authored Sep 10, 2023
1 parent b827ab0 · commit 24f0893
Showing 1 changed file with 7 additions and 6 deletions.
@@ -150,18 +150,19 @@ def optimize_sd_pipeline(
             op_block_list=op_block_list + force_fp32_operators[name],
         )
 
-        if enable_runtime_optimization and (float16 or (name not in ["unet"])):
+        if enable_runtime_optimization:
             # Use this step to see the final graph that executed by Onnx Runtime.
-            # Note that ORT cannot save model larger than 2GB so we exclude unet float32 model.
             # This step is optional since it has no impact on performance except model loading time.
             with tempfile.TemporaryDirectory() as tmp_dir:
                 # Save to a temporary file so that we can load it with Onnx Runtime.
                 logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
                 tmp_model_path = Path(tmp_dir) / "model.onnx"
-                m.save_model_to_file(str(tmp_model_path))
-                ort_optimized_model_path = tmp_model_path
+                m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+                ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
                 optimize_by_onnxruntime(
-                    str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path)
+                    str(tmp_model_path),
+                    use_gpu=True,
+                    optimized_model_path=str(ort_optimized_model_path),
+                    save_as_external_data=use_external_data_format,
                 )
                 model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
                 m = model_type_class_mapping[model_type](model)
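Note that the optimized graph is now written to a separate `optimized.onnx` rather than back over `model.onnx`, so the input model and its external data file cannot be clobbered while ONNX Runtime reads them. Below is a runnable miniature of the same save/optimize/reload round trip; a trivial graph stands in for the unet, the `optimize_by_onnxruntime` keyword arguments are the ones used in the diff above, and everything else is illustrative:

```python
import tempfile
from pathlib import Path

import onnx
from onnx import TensorProto, helper
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime

# A trivial Identity graph stands in for the (multi-gigabyte) unet.
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "g", [inp], [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 8  # keep the IR version compatible with older onnxruntime builds

use_external_data_format = True  # what --use_external_data_format toggles

with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_model_path = Path(tmp_dir) / "model.onnx"
    onnx.save_model(model, str(tmp_model_path), save_as_external_data=use_external_data_format)

    # Write ORT's optimized graph to a different file so it does not
    # collide with the input model or its external data.
    ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
    optimize_by_onnxruntime(
        str(tmp_model_path),
        use_gpu=False,  # the pipeline passes True; False keeps this sketch CPU-only
        optimized_model_path=str(ort_optimized_model_path),
        save_as_external_data=use_external_data_format,
    )

    optimized = onnx.load(str(ort_optimized_model_path), load_external_data=True)
    print(len(optimized.graph.node), "node(s) after ORT optimization")
```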
