From 7fe1ec7ece2356b6e3f1d03b5a00af84de252b51 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 3 Oct 2023 14:23:02 -0700 Subject: [PATCH] add demo_txt2img --- .../tools/transformers/io_binding_helper.py | 1 + .../models/stable_diffusion/demo_txt2img.py | 96 +++++++ .../stable_diffusion/demo_txt2img_xl.py | 213 +------------- .../models/stable_diffusion/demo_utils.py | 263 ++++++++++++++++++ .../stable_diffusion/diffusion_models.py | 4 +- .../models/stable_diffusion/engine_builder.py | 33 +-- .../engine_builder_tensorrt.py | 3 +- .../models/stable_diffusion/ort_optimizer.py | 36 ++- 8 files changed, 408 insertions(+), 241 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 71c1a21d8f768..de17f195c99cc 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]): if name in self.input_names: if self.enable_cuda_graph: assert self.input_tensors[name].nelement() == tensor.nelement() + assert self.input_tensors[name].dtype == tensor.dtype assert tensor.device.type == "cuda" # Please install cuda-python package with a version corresponding to CUDA in your machine. from cuda import cudart diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py new file mode 100644 index 0000000000000..21af105bdde8c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -0,0 +1,96 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths, get_engine_type +from pipeline_txt2img import Txt2ImgPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + + args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + pipeline_info = PipelineInfo(args.version) + pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + pipeline.load_resources(image_height, image_width, batch_size) + + def run_inference(warmup=False): + images, time_base = pipeline.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="images", + ) + + return images, time_base + + if not args.disable_cuda_graph: + # inference once to get cuda graph + images, _ = run_inference(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + images, _ = run_inference(warmup=True) + + print("[I] Running StableDiffusion pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_inference(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 346bed4b70986..f13d5eab2ed6a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -21,6 +21,7 @@ # -------------------------------------------------------------------------- import argparse +import coloredlogs import torch from cuda import cudart @@ -28,154 +29,15 @@ from engine_builder import EngineType, get_engine_paths, get_engine_type from pipeline_img2img_xl import Img2ImgXLPipeline from pipeline_txt2img_xl import Txt2ImgXLPipeline - - -def parse_arguments(): - parser = argparse.ArgumentParser(description="Options for Stable Diffusion XL Demo", conflict_handler="resolve") - parser.add_argument( - "--engine", - type=str, - default="ORT_TRT", - choices=["ORT_TRT", "TRT"], - help="Backend engine. Default is ORT_TRT, which means OnnxRuntime TensorRT execution provider.", - ) - - parser.add_argument( - "--version", type=str, default="xl-1.0", choices=["xl-1.0"], help="Version of Stable Diffusion XL" - ) - parser.add_argument( - "--height", type=int, default=1024, help="Height of image to generate (must be multiple of 8). Default is 1024." - ) - parser.add_argument( - "--width", type=int, default=1024, help="Height of image to generate (must be multiple of 8). Default is 1024." - ) - - parser.add_argument( - "--scheduler", - type=str, - default="DDIM", - choices=["DDIM", "EulerA", "UniPC"], - help="Scheduler for diffusion process", - ) - - parser.add_argument( - "--work-dir", - default="", - help="Root Directory to store torch or ONNX models, built engines and output images etc", - ) - - parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation") - - parser.add_argument( - "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." - ) - parser.add_argument( - "--repeat-prompt", - type=int, - default=1, - choices=[1, 2, 4, 8, 16], - help="Number of times to repeat the prompt (batch size multiplier). Default is 1.", - ) - parser.add_argument( - "--denoising-steps", - type=int, - default=30, - help="Number of denoising steps in each of base and refiner. Default is 30.", - ) - parser.add_argument( - "--guidance", - type=float, - default=5.0, - help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", - ) - - # ONNX export - parser.add_argument( - "--onnx-opset", - type=int, - default=17, - choices=range(14, 18), - help="Select ONNX opset version to target for exported models. Default is 17.", - ) - parser.add_argument( - "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models" - ) - parser.add_argument( - "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models" - ) - - # Framework model ckpt - parser.add_argument("--framework-model-dir", default="pytorch_model", help="Directory for HF saved models") - parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints") - - # Engine build options. - parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine") - parser.add_argument( - "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." - ) - parser.add_argument( - "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." - ) - - # Inference related options - parser.add_argument( - "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance" - ) - parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling") - parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results") - parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") - - # TensorRT only options - group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") - group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from") - group.add_argument( - "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." - ) - group.add_argument( - "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." - ) - group.add_argument( - "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." - ) - - # Pipeline options - parser.add_argument( - "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipeline." - ) - - return parser.parse_args() - +from demo_utils import parse_arguments, repeat_prompt, init_pipeline if __name__ == "__main__": - args = parse_arguments() - - if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: - print("[I] CUDA Graph is disabled since dynamic input shape is configured.") - args.disable_cuda_graph = True + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + prompt, negative_prompt = repeat_prompt(args) - print(args) - - # Process prompt - if not isinstance(args.prompt, list): - raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") - prompt = args.prompt * args.repeat_prompt - - if not isinstance(args.negative_prompt, list): - raise ValueError( - f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" - ) - if len(args.negative_prompt) == 1: - negative_prompt = args.negative_prompt * len(prompt) - else: - negative_prompt = args.negative_prompt - - # Validate image dimensions image_height = args.height image_width = args.width - if image_height % 8 != 0 or image_width % 8 != 0: - raise ValueError( - f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." - ) # Register TensorRT plugins engine_type = get_engine_type(args.engine) @@ -194,73 +56,12 @@ def parse_arguments(): f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - def init_pipeline(pipeline_class, pipeline_info, engine_type): - onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( - args.work_dir, pipeline_info, engine_type - ) - - # Initialize demo - pipeline = pipeline_class( - pipeline_info, - scheduler=args.scheduler, - output_dir=output_dir, - hf_token=args.hf_token, - verbose=False, - nvtx_profile=args.nvtx_profile, - max_batch_size=max_batch_size, - use_cuda_graph=not args.disable_cuda_graph, - framework_model_dir=framework_model_dir, - engine_type=engine_type, - ) - - # Load TensorRT engines and pytorch modules - if engine_type == EngineType.ORT_TRT: - pipeline.backend.build_engines( - engine_dir, - framework_model_dir, - onnx_dir, - args.onnx_opset, - opt_image_height=image_height, - opt_image_width=image_width, - opt_batch_size=len(prompt), - # force_export=args.force_onnx_export, - # force_optimize=args.force_onnx_optimize, - force_engine_rebuild=args.force_engine_build, - static_batch=not args.build_dynamic_batch, - static_image_shape=not args.build_dynamic_shape, - max_workspace_size=0, - device_id=torch.cuda.current_device(), - ) - elif engine_type == EngineType.TRT: - # Load TensorRT engines and pytorch modules - pipeline.backend.load_engines( - engine_dir, - framework_model_dir, - onnx_dir, - args.onnx_opset, - opt_batch_size=len(prompt), - opt_image_height=image_height, - opt_image_width=image_width, - force_export=args.force_onnx_export, - force_optimize=args.force_onnx_optimize, - force_build=args.force_engine_build, - static_batch=not args.build_dynamic_batch, - static_shape=not args.build_dynamic_shape, - enable_refit=args.build_enable_refit, - enable_preview=args.build_preview_features, - enable_all_tactics=args.build_all_tactics, - timing_cache=timing_cache, - onnx_refit_dir=args.onnx_refit_dir, - ) - - return pipeline - base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner) - base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type) + base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) if args.enable_refiner: refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True) - refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) if engine_type == EngineType.TRT: max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py new file mode 100644 index 0000000000000..d872fbafe59d7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -0,0 +1,263 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import argparse + +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths, get_engine_type +from pipeline_img2img_xl import Img2ImgXLPipeline +from pipeline_txt2img_xl import Txt2ImgXLPipeline + + +class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + +def parse_arguments(is_xl: bool, description: str): + parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--engine", + type=str, + default="ORT_TRT", + choices=["ORT_TRT", "TRT"], + help="Backend engine. Default is OnnxRuntime CUDA execution provider.", + ) + + supported_versions = PipelineInfo.supported_versions(is_xl) + parser.add_argument( + "--version", + type=str, + default=supported_versions[-1] if is_xl else "1.5", + choices=supported_versions, + help="Version of Stable Diffusion" + (" XL." if is_xl else "."), + ) + + parser.add_argument( + "--height", + type=int, + default=1024 if is_xl else 512, + help="Height of image to generate (must be multiple of 8).", + ) + parser.add_argument( + "--width", type=int, default=1024 if is_xl else 512, help="Height of image to generate (must be multiple of 8)." + ) + + parser.add_argument( + "--scheduler", + type=str, + default="DDIM", + choices=["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process", + ) + + parser.add_argument( + "--work-dir", + default=".", + help="Root Directory to store torch or ONNX models, built engines and output images etc.", + ) + + parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + + parser.add_argument( + "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." + ) + parser.add_argument( + "--repeat-prompt", + type=int, + default=1, + choices=[1, 2, 4, 8, 16], + help="Number of times to repeat the prompt (batch size multiplier).", + ) + + parser.add_argument( + "--denoising-steps", + type=int, + default=30 if is_xl else 50, + help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."), + ) + + parser.add_argument( + "--guidance", + type=float, + default=5.0 if is_xl else 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=None, + choices=range(14, 18), + help="Select ONNX opset version to target for exported models.", + ) + parser.add_argument( + "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." + ) + parser.add_argument( + "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models. Default is pytorch_model.", + ) + parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") + + # Engine build options. + parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") + parser.add_argument( + "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." + ) + parser.add_argument( + "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." + ) + + # Inference related options + parser.add_argument( + "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance." + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") + parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + + # TensorRT only options + group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") + group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") + group.add_argument( + "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." + ) + group.add_argument( + "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." + ) + group.add_argument( + "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." + ) + + # Pipeline options + if is_xl: + parser.add_argument( + "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines." + ) + + args = parser.parse_args() + + # Validate image dimensions + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + ) + + if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: + print("[I] CUDA Graph is disabled since dynamic input shape is configured.") + args.disable_cuda_graph = True + + if args.onnx_opset is None: + args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + + print(args) + + return args + + +def repeat_prompt(args): + if not isinstance(args.prompt, list): + raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") + prompt = args.prompt * args.repeat_prompt + + if not isinstance(args.negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" + ) + if len(args.negative_prompt) == 1: + negative_prompt = args.negative_prompt * len(prompt) + else: + negative_prompt = args.negative_prompt + + return prompt, negative_prompt + + +def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + args.work_dir, pipeline_info, engine_type + ) + + # Initialize demo + pipeline = pipeline_class( + pipeline_info, + scheduler=args.scheduler, + output_dir=output_dir, + hf_token=args.hf_token, + verbose=False, + nvtx_profile=args.nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=not args.disable_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + if engine_type == EngineType.ORT_TRT: + # Build TensorRT EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + # force_export=args.force_onnx_export, + # force_optimize=args.force_onnx_optimize, + force_engine_rebuild=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_image_shape=not args.build_dynamic_shape, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.TRT: + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_batch_size=batch_size, + opt_image_height=args.height, + opt_image_width=args.height, + force_export=args.force_onnx_export, + force_optimize=args.force_onnx_optimize, + force_build=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_shape=not args.build_dynamic_shape, + enable_refit=args.build_enable_refit, + enable_preview=args.build_preview_features, + enable_all_tactics=args.build_all_tactics, + timing_cache=timing_cache, + onnx_refit_dir=args.onnx_refit_dir, + ) + + return pipeline diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 0f54b3a8c0183..8f78bbfd39c05 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -121,8 +121,8 @@ def vae_scaling_factor(self) -> float: return 0.13025 if self.is_sd_xl() else 0.18215 @staticmethod - def supported_versions(): - return ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base", "xl-1.0"] + def supported_versions(is_xl:bool): + return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] def name(self) -> str: if self.version == "1.4": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index c2d9683953915..64c3c5bc80ecb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -91,6 +91,10 @@ def load_models(self, framework_model_dir: str): if hasattr(torch.nn.functional, "scaled_dot_product_attention"): delattr(torch.nn.functional, "scaled_dot_product_attention") + # For TRT or ORT_TRT, we will export fp16 torch model for UNet. + # For ORT_CUDA, we export fp32 model first, then optimize to fp16. + export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + if "clip" in self.stages: self.models["clip"] = CLIP( self.pipeline_info, @@ -114,7 +118,7 @@ def load_models(self, framework_model_dir: str): self.pipeline_info, None, # not loaded yet device=self.torch_device, - fp16=True, + fp16=export_fp16_unet, max_batch_size=self.max_batch_size, unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), ) @@ -124,7 +128,7 @@ def load_models(self, framework_model_dir: str): self.pipeline_info, None, # not loaded yet device=self.torch_device, - fp16=True, + fp16=export_fp16_unet, max_batch_size=self.max_batch_size, unet_dim=4, time_dim=(5 if self.pipeline_info.is_sd_xl_refiner() else 6), @@ -162,23 +166,16 @@ def vae_decode(self, latents): return images -def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, engine_sub_dir=True): +def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): root_dir = work_dir or "." - short_name = pipeline_info.short_name() - engine_name = engine_type.name.lower() - - if engine_sub_dir: - onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") - engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") - output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") - timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") - else: - onnx_dir = os.path.join(root_dir, short_name, "onnx") - engine_dir = os.path.join(root_dir, short_name, engine_name) - output_dir = os.path.join(root_dir, short_name, "output") - timing_cache = os.path.join(root_dir, "timing_cache") - - framework_model_dir = os.path.join(root_dir, "torch_model") + + # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since + # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model. + onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") + engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") + output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") + framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py index 8ee15df47d9fd..4a924abfb8600 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -504,5 +504,4 @@ def activate_engines(self, shared_device_memory=None): engine.activate(reuse_device_memory=self.shared_device_memory) def run_engine(self, model_name, feed_dict): - engine = self.engines[model_name] - return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) + return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 0824c8f07d6e2..166940df331c7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -10,6 +10,7 @@ import logging import tempfile from pathlib import Path +from packaging import version import onnx @@ -19,6 +20,8 @@ from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model +from optimize_pipeline import has_external_data + logger = logging.getLogger(__name__) @@ -32,21 +35,22 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model): + def optimize_by_ort(self, onnx_model, use_external_data_format=False): # Use this step to see the final graph that executed by Onnx Runtime. with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, ) model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep_io_types=False, keep_outputs=None): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") fusion_options = FusionOptions(self.model_type) @@ -54,6 +58,8 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False + use_external_data_format = has_external_data(input_fp32_onnx_path) + m = optimize_model( input_fp32_onnx_path, model_type=self.model_type, @@ -64,21 +70,25 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): use_gpu=True, ) - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + if keep_outputs is None and self.model_type == "clip": + # remove the pooler_output, and only keep the first output. + keep_outputs = ["text_embeddings"] + + if keep_outputs: + m.prune_graph(outputs=keep_outputs) if float16: logger.info("Convert to float16 ...") m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], + keep_io_types=keep_io_types, ) - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 - if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) + # Note that ORT < 1.16 could not save model larger than 2GB. + from onnxruntime import __version__ as ort_version + if version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format: + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)