From d6dad96923423d5a8ae184f3e88fdf48aa6aab57 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 5 Oct 2023 08:19:20 -0700 Subject: [PATCH] Add CUDA EP in StableDiffusion demo (#17788) Add CUDA EP to the demo of stable diffusion. ### A100 Performance Test | Engine Property | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | ORT_CUDA Latency (ms) | TORCH Latency (ms) -- | -- | -- | -- | -- | -- | -- SD 1.5, 50 steps, 512x512 | Static Input Shape | 1 | 861 | 851 | 861 | N/A SD 1.5, 50 steps, 512x512 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 974 | 1079 | 928 | 1222 SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 2492 | OOM | 1901 | 1971 SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 4 |9091 | OOM | 6785 | 6700 We can see that ORT_CUDA is the most robust one for handling dynamic input shape. PyTorch could be a good choice if you run large batch size. The above result is from one A100-SXM4-80GB GPU (in Standard_ND96amsr_A100_v4 Azure VM) with 50 steps to generate 512x512 or 768x768 images using StableDiffusion 1.5. Onnxruntime-gpu is built from source, and the following packages or libraries are used in this test: * tensorrt==8.6.1.post1 * torch==2.2.0.dev20230920+cu121 * transformers==4.31.0 * diffusers==0.19.3 * onnx==1.14.1 * onnx-graphsurgeon==0.3.27 * polygraphy==0.47.1 * protobuf==3.20.2 * onnxruntime-gpu==1.17.0 (built from source of main branch) * CUDA 12.2.2 * cuDNN 8.9.5.29 * python 3.10.13 For static input shape, the engine is built with static batch size and static image shape, and cuda graph is enabled. For dynamic input shape, the engine is built to support dynamic batch size and dynamic image shape, and cuda graph is disabled. The TensorRT engine is built for batch size 1~4, image size 256x256 ~ 1024x1024, and the optimized image size is 512x512. The script to test static and dynamic input shape are like the following: ``` prompt="a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital paintining" for e in TRT ORT_TRT ORT_CUDA do python demo_txt2img.py --engine $e "$prompt" python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "$prompt" python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape --height 768 --width 768 "$prompt" done ``` Performance of PyTorch is from commands like the following: ``` python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 512 --width 512 python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 768 --width 768 python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 4 --height 768 --width 768 ``` --- .../models/stable_diffusion/README.md | 44 +++-- .../models/stable_diffusion/demo_utils.py | 45 ++++- .../stable_diffusion/diffusion_models.py | 44 ++++- .../models/stable_diffusion/engine_builder.py | 5 +- .../engine_builder_ort_cuda.py | 172 ++++++++++++++++++ .../engine_builder_ort_trt.py | 4 +- .../models/stable_diffusion/ort_optimizer.py | 18 +- .../pipeline_stable_diffusion.py | 3 + .../stable_diffusion/requirements-cuda.txt | 14 -- .../stable_diffusion/requirements-cuda11.txt | 21 +++ .../stable_diffusion/requirements-cuda12.txt | 21 +++ .../requirements-tensorrt.txt | 2 - .../models/stable_diffusion/requirements.txt | 1 + 13 files changed, 341 insertions(+), 53 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 1fbd5092a719a..d937e3f4213e0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -72,38 +72,52 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +It is recommended to create a Conda environment with Python 3.10 for the following setup: +``` +conda create -n py310 python=3.10 +conda activate py310 +``` + ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8. -If you use CUDA 12.*, you will need build onnxruntime-gpu from source. +First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. + +#### CUDA 11.8: + +In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: ``` -conda create -n py38 python=3.8 -conda activate py38 -pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com -pip install -r requirements-cuda.txt +pip install -r requirements-cuda11.txt ``` -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended. -#### Install Nightly (Optional) +We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. -Skip this step if you use onnxruntime-gpu package from official releases. +#### CUDA 12.*: +The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). -To try latest optimizations, you can install [ort-nightly-gpu](https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/ort-nightly-gpu/) package like the following: +``` +git clone --recursive https://github.com/Microsoft/onnxruntime.git +cd onnxruntime +pip install -r requirements-dev.txt +``` +Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) +or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxruntime/blob/8df5f4e0df1f3b9ceeb0f1f2561b09727ace9b37/build_trt.cmd) to build and install onnxruntime-gpu wheel. +Then install other python packages like the following: ``` -pip uninstall onnxruntime-gpu -pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install -r requirements-cuda12.txt ``` +Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. ### Setup Environment (ROCm) -It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.8, 3.9 or 3.10. +It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.10. Note that Windows is not supported for ROCm at the moment. ``` -conda create -n py38 python=3.8 -conda activate py38 wget https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/torch-1.12.1%2Brocm5.4-cp38-cp38-linux_x86_64.whl pip install torch-1.12.1+rocm5.4-cp38-cp38-linux_x86_64.whl pip install -r requirements-rocm.txt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 5fdafc463f4e2..796e83f70d6e4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -34,12 +34,15 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte def parse_arguments(is_xl: bool, description: str): parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + engines = ["ORT_TRT", "TRT"] if is_xl else ["ORT_CUDA", "ORT_TRT", "TRT"] + parser.add_argument( "--engine", type=str, - default="ORT_TRT", - choices=["ORT_TRT", "TRT"], - help="Backend engine. Default is OnnxRuntime CUDA execution provider.", + default=engines[0], + choices=engines, + help="Backend engine in {engines}. " + "ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT", ) supported_versions = PipelineInfo.supported_versions(is_xl) @@ -106,7 +109,7 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument( "--onnx-opset", type=int, - default=17, + default=None, choices=range(14, 18), help="Select ONNX opset version to target for exported models.", ) @@ -163,6 +166,16 @@ def parse_arguments(is_xl: bool, description: str): args = parser.parse_args() + if ( + args.engine in ["ORT_CUDA", "ORT_TRT"] + and (args.force_onnx_export or args.force_onnx_optimize) + and not args.force_engine_build + ): + raise ValueError( + "For ORT_CUDA or ORT_TRT, --force_onnx_export and --force_onnx_optimize are not supported. " + "Please use --force_engine_build instead." + ) + # Validate image dimensions if args.height % 8 != 0 or args.width % 8 != 0: raise ValueError( @@ -173,6 +186,9 @@ def parse_arguments(is_xl: bool, description: str): print("[I] CUDA Graph is disabled since dynamic input shape is configured.") args.disable_cuda_graph = True + if args.onnx_opset is None: + args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + print(args) return args @@ -197,7 +213,7 @@ def repeat_prompt(args): def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( - args.work_dir, pipeline_info, engine_type + work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type ) # Initialize demo @@ -214,7 +230,24 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si engine_type=engine_type, ) - if engine_type == EngineType.ORT_TRT: + if engine_type == EngineType.ORT_CUDA: + # Build CUDA EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + device_id=torch.cuda.current_device(), + disable_cuda_graph_models=[ + "clip2", # TODO: Add ArgMax cuda kernel to enable cuda graph for clip2. + "unetxl", + ], + ) + elif engine_type == EngineType.ORT_TRT: # Build TensorRT EP engines and load pytorch modules pipeline.backend.build_engines( engine_dir, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 951cd66005f4c..7726abb9f9e4d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -298,9 +298,16 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, def get_shape_dict(self, batch_size, image_height, image_width): return None + def fp32_input_output_names(self) -> List[str]: + """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model. + This is a list of input or output names that are kept as float32 during converting. + For the first version, we will use same data type as TensorRT. + """ + return [] + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): optimizer = self.get_ort_optimizer() - optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=self.fp32_input_output_names()) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) @@ -416,7 +423,7 @@ def get_sample_input(self, batch_size, image_height, image_width): self.check_dims(batch_size, image_height, image_width) return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) - def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path): + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False): graph: GraphProto = model.graph hidden_layers = -1 for i in range(len(graph.node)): @@ -457,7 +464,29 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path) onnx_model = OnnxModel(model) onnx_model.add_node(cast_node) - onnx_model.save_model_to_file(optimized_onnx_path) + onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + optimizer = self.get_ort_optimizer() + if not self.output_hidden_state: + optimizer.optimize( + input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=[], keep_outputs=["text_embeddings"] + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to add hidden_states to graph output ...") + tmp_model_path = os.path.join(tmp_dir, "model.onnx") + + model = onnx.load(input_onnx_path) + self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True) + optimizer.optimize( + tmp_model_path, + optimized_onnx_path, + to_fp16, + keep_io_types=[], + keep_outputs=["text_embeddings", "hidden_states"], + ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) @@ -598,6 +627,9 @@ def get_sample_input(self, batch_size, image_height, image_width): torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ) + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + class UNetXL(BaseModel): def __init__( @@ -703,6 +735,9 @@ def get_sample_input(self, batch_size, image_height, image_width): }, ) + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + # VAE Decoder class VAE(BaseModel): @@ -773,6 +808,9 @@ def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + def fp32_input_output_names(self) -> List[str]: + return ["latent", "images"] + def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 64c3c5bc80ecb..fdf05ffc799d9 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -77,9 +77,8 @@ def get_cached_model_name(self, model_name): def get_onnx_path(self, model_name, onnx_dir, opt=True): engine_name = self.engine_type.name.lower() - onnx_model_dir = os.path.join( - onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") - ) + directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + onnx_model_dir = os.path.join(onnx_dir, directory_name) os.makedirs(onnx_model_dir, exist_ok=True) return os.path.join(onnx_model_dir, "model.onnx") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py new file mode 100644 index 0000000000000..936d04e8a1c43 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -0,0 +1,172 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +class OrtCudaEngine(CudaSession): + def __init__(self, onnx_path, device_id: int = 0, enable_cuda_graph=False, disable_optimization=False): + self.onnx_path = onnx_path + self.provider = "CUDAExecutionProvider" + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + + session_options = ort.SessionOptions() + # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time. + if disable_optimization: + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + logger.info("creating CUDA EP session for %s", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + (self.provider, self.provider_options), + "CPUExecutionProvider", + ], + ) + logger.info("created CUDA EP session for %s", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtCudaEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_CUDA, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height=512, + opt_image_width=512, + opt_batch_size=1, + force_engine_rebuild=False, + device_id=0, + disable_cuda_graph_models=None, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info("Exporting model: %s", onnx_path) + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(): + # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Run graph optimization and convert to mixed precision (computation in FP16) + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimized model: %s", onnx_opt_path) + model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=True) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + + use_cuda_graph = self.use_cuda_graph + if self.use_cuda_graph and disable_cuda_graph_models and model_name in disable_cuda_graph_models: + use_cuda_graph = False + + engine = OrtCudaEngine(onnx_opt_path, device_id=device_id, enable_cuda_graph=use_cuda_graph) + logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options) + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index 253cdcc45bf2e..a6bbd4ee7eeb7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -32,7 +32,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - print("creating TRT EP session for ", onnx_path) + logger.info("creating TRT EP session for %s", onnx_path) ort_session = ort.InferenceSession( onnx_path, session_options, @@ -40,7 +40,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works ("TensorrtExecutionProvider", self.ort_trt_provider_options), ], ) - print("created TRT EP session for ", onnx_path) + logger.info("created TRT EP session for %s", onnx_path) device = torch.device("cuda", device_id) super().__init__(ort_session, device, enable_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 2c4b8e8a1639e..24570a6ef62da 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,7 +12,7 @@ from pathlib import Path import onnx -from optimize_pipeline import has_external_data +from packaging import version from onnxruntime.transformers.fusion_options import FusionOptions from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel @@ -59,8 +59,6 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False - use_external_data_format = has_external_data(input_fp32_onnx_path) - m = optimize_model( input_fp32_onnx_path, model_type=self.model_type, @@ -71,10 +69,6 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep use_gpu=True, ) - if keep_outputs is None and self.model_type == "clip": - # remove the pooler_output, and only keep the first output. - keep_outputs = ["text_embeddings"] - if keep_outputs: m.prune_graph(outputs=keep_outputs) @@ -84,8 +78,16 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep keep_io_types=keep_io_types, ) + use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF + # Note that ORT < 1.16 could not save model larger than 2GB. - if float16 or (self.model_type != "unet"): + # This step is is optional since it has no impact on inference latency. + # The optimized model is not portable. It could only run in the same execution provider (CUDA EP in this case). + # When the model has been optimized by onnxruntime, we can disable optimization in SessionOption + # to save session creation time. Another benefit is to inspect the final graph for developing purpose. + from onnxruntime import __version__ as ort_version + + if version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format: m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index a053c9d5d0835..87443c990450b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -30,6 +30,7 @@ from diffusion_models import PipelineInfo, get_tokenizer from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler from engine_builder import EngineType +from engine_builder_ort_cuda import OrtCudaEngineBuilder from engine_builder_ort_trt import OrtTensorrtEngineBuilder from engine_builder_tensorrt import TensorrtEngineBuilder @@ -135,6 +136,8 @@ def __init__( self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) elif engine_type == EngineType.ORT_TRT: self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_CUDA: + self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) else: raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt deleted file mode 100644 index 2a3caf4c2392b..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ /dev/null @@ -1,14 +0,0 @@ --r requirements.txt -onnxruntime-gpu>=1.16 -py3nvml>=0.2.7 - -# cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. -cuda-python==11.8.0 -# For windows, cuda-python need the following -pywin32; platform_system == "Windows" - -nvtx - -# To export onnx, please install PyTorch 2.10 like -# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 -# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt new file mode 100644 index 0000000000000..5f908c4f5ff39 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. +onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==11.8.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu118 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt new file mode 100644 index 0000000000000..e4e765831c1b3 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. +# onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==12.1.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for 12.1 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt deleted file mode 100644 index 5b59c64ab7470..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ /dev/null @@ -1,2 +0,0 @@ --r requirements-cuda.txt -tensorrt>=8.6.1 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index d4e6c9fa07695..9386a941fb323 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -5,6 +5,7 @@ accelerate onnx>=1.13.0 coloredlogs packaging +# Use newer version of protobuf might cause crash protobuf==3.20.3 psutil sympy