From aacf0df4e2645b78b9fc3f46af74da1c8a373763 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 4 Oct 2023 21:56:51 +0000 Subject: [PATCH 1/3] Add CUDA EP in StableDiffusion demo --- .../models/stable_diffusion/README.md | 44 +++-- .../models/stable_diffusion/demo_txt2img.py | 4 + .../models/stable_diffusion/demo_utils.py | 45 ++++- .../stable_diffusion/diffusion_models.py | 44 ++++- .../models/stable_diffusion/engine_builder.py | 5 +- .../engine_builder_ort_cuda.py | 172 ++++++++++++++++++ .../models/stable_diffusion/ort_optimizer.py | 15 +- .../pipeline_stable_diffusion.py | 3 + .../stable_diffusion/requirements-cuda.txt | 14 -- .../stable_diffusion/requirements-cuda11.txt | 21 +++ .../stable_diffusion/requirements-cuda12.txt | 21 +++ .../requirements-tensorrt.txt | 2 - .../models/stable_diffusion/requirements.txt | 1 + 13 files changed, 343 insertions(+), 48 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 1fbd5092a719a..d937e3f4213e0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -72,38 +72,52 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +It is recommended to create a Conda environment with Python 3.10 for the following setup: +``` +conda create -n py310 python=3.10 +conda activate py310 +``` + ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8. -If you use CUDA 12.*, you will need build onnxruntime-gpu from source. +First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. + +#### CUDA 11.8: + +In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: ``` -conda create -n py38 python=3.8 -conda activate py38 -pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com -pip install -r requirements-cuda.txt +pip install -r requirements-cuda11.txt ``` -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended. -#### Install Nightly (Optional) +We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. 
For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. -Skip this step if you use onnxruntime-gpu package from official releases. +#### CUDA 12.*: +The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). -To try latest optimizations, you can install [ort-nightly-gpu](https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/ort-nightly-gpu/) package like the following: +``` +git clone --recursive https://github.com/Microsoft/onnxruntime.git +cd onnxruntime +pip install -r requirements-dev.txt +``` +Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) +or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxruntime/blob/8df5f4e0df1f3b9ceeb0f1f2561b09727ace9b37/build_trt.cmd) to build and install onnxruntime-gpu wheel. +Then install other python packages like the following: ``` -pip uninstall onnxruntime-gpu -pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install -r requirements-cuda12.txt ``` +Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. ### Setup Environment (ROCm) -It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.8, 3.9 or 3.10. +It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.10. Note that Windows is not supported for ROCm at the moment. ``` -conda create -n py38 python=3.8 -conda activate py38 wget https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/torch-1.12.1%2Brocm5.4-cp38-cp38-linux_x86_64.whl pip install torch-1.12.1+rocm5.4-cp38-cp38-linux_x86_64.whl pip install -r requirements-rocm.txt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index f6e00063a6391..0eeab19e8328a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -20,6 +20,8 @@ # limitations under the License. 
# --------------------------------------------------------------------------

+import logging
+
 import coloredlogs
 from cuda import cudart
 from demo_utils import init_pipeline, parse_arguments, repeat_prompt
@@ -27,6 +29,8 @@ from engine_builder import EngineType, get_engine_type
 from pipeline_txt2img import Txt2ImgPipeline
 
+logger = logging.getLogger(__name__)
+
 if __name__ == "__main__":
     coloredlogs.install(fmt="%(funcName)20s: %(message)s")
 
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
index 5fdafc463f4e2..796e83f70d6e4 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
@@ -34,12 +34,15 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte
 def parse_arguments(is_xl: bool, description: str):
     parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter)
 
+    engines = ["ORT_TRT", "TRT"] if is_xl else ["ORT_CUDA", "ORT_TRT", "TRT"]
+
     parser.add_argument(
         "--engine",
         type=str,
-        default="ORT_TRT",
-        choices=["ORT_TRT", "TRT"],
-        help="Backend engine. Default is OnnxRuntime CUDA execution provider.",
+        default=engines[0],
+        choices=engines,
+        help=f"Backend engine in {engines}. "
+        "ORT_CUDA is the CUDA execution provider; ORT_TRT is the TensorRT execution provider; TRT is TensorRT.",
     )
 
     supported_versions = PipelineInfo.supported_versions(is_xl)
@@ -106,7 +109,7 @@ def parse_arguments(is_xl: bool, description: str):
     parser.add_argument(
         "--onnx-opset",
         type=int,
-        default=17,
+        default=None,
         choices=range(14, 18),
         help="Select ONNX opset version to target for exported models.",
     )
@@ -163,6 +166,16 @@ def parse_arguments(is_xl: bool, description: str):
 
     args = parser.parse_args()
 
+    if (
+        args.engine in ["ORT_CUDA", "ORT_TRT"]
+        and (args.force_onnx_export or args.force_onnx_optimize)
+        and not args.force_engine_build
+    ):
+        raise ValueError(
+            "For ORT_CUDA or ORT_TRT, --force_onnx_export and --force_onnx_optimize are not supported. "
+            "Please use --force_engine_build instead."
+        )
+
     # Validate image dimensions
     if args.height % 8 != 0 or args.width % 8 != 0:
         raise ValueError(
@@ -173,6 +186,9 @@
         print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
         args.disable_cuda_graph = True
 
+    if args.onnx_opset is None:
+        args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17
+
     print(args)
     return args
 
@@ -197,7 +213,7 @@ def repeat_prompt(args):
 
 def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size):
     onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
-        args.work_dir, pipeline_info, engine_type
+        work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type
     )
 
     # Initialize demo
@@ -214,7 +230,24 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
         engine_type=engine_type,
     )
 
-    if engine_type == EngineType.ORT_TRT:
+    if engine_type == EngineType.ORT_CUDA:
+        # Build CUDA EP engines and load pytorch modules
+        pipeline.backend.build_engines(
+            engine_dir=engine_dir,
+            framework_model_dir=framework_model_dir,
+            onnx_dir=onnx_dir,
+            onnx_opset=args.onnx_opset,
+            opt_image_height=args.height,
+            opt_image_width=args.width,
+            opt_batch_size=batch_size,
+            force_engine_rebuild=args.force_engine_build,
+            device_id=torch.cuda.current_device(),
+            disable_cuda_graph_models=[
+                "clip2",  # TODO: Add ArgMax cuda kernel to enable cuda graph for clip2.
+                "unetxl",
+            ],
+        )
+    elif engine_type == EngineType.ORT_TRT:
         # Build TensorRT EP engines and load pytorch modules
         pipeline.backend.build_engines(
             engine_dir,
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
index 951cd66005f4c..7726abb9f9e4d 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
@@ -298,9 +298,16 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch,
     def get_shape_dict(self, batch_size, image_height, image_width):
         return None
 
+    def fp32_input_output_names(self) -> List[str]:
+        """For CUDA EP, the ONNX model is exported in FP32 first and then converted to a mixed precision model.
+        This is the list of input and output names that are kept in float32 during the conversion.
+        For this first version, the same data types as TensorRT are used.
+ """ + return [] + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): optimizer = self.get_ort_optimizer() - optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=self.fp32_input_output_names()) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) @@ -416,7 +423,7 @@ def get_sample_input(self, batch_size, image_height, image_width): self.check_dims(batch_size, image_height, image_width) return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) - def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path): + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False): graph: GraphProto = model.graph hidden_layers = -1 for i in range(len(graph.node)): @@ -457,7 +464,29 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path) onnx_model = OnnxModel(model) onnx_model.add_node(cast_node) - onnx_model.save_model_to_file(optimized_onnx_path) + onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + optimizer = self.get_ort_optimizer() + if not self.output_hidden_state: + optimizer.optimize( + input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=[], keep_outputs=["text_embeddings"] + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to add hidden_states to graph output ...") + tmp_model_path = os.path.join(tmp_dir, "model.onnx") + + model = onnx.load(input_onnx_path) + self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True) + optimizer.optimize( + tmp_model_path, + optimized_onnx_path, + to_fp16, + keep_io_types=[], + keep_outputs=["text_embeddings", "hidden_states"], + ) def optimize_trt(self, input_onnx_path, optimized_onnx_path): onnx_graph = onnx.load(input_onnx_path) @@ -598,6 +627,9 @@ def get_sample_input(self, batch_size, image_height, image_width): torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ) + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + class UNetXL(BaseModel): def __init__( @@ -703,6 +735,9 @@ def get_sample_input(self, batch_size, image_height, image_width): }, ) + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + # VAE Decoder class VAE(BaseModel): @@ -773,6 +808,9 @@ def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + def fp32_input_output_names(self) -> List[str]: + return ["latent", "images"] + def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 64c3c5bc80ecb..fdf05ffc799d9 100644 --- 
a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -77,9 +77,8 @@ def get_cached_model_name(self, model_name): def get_onnx_path(self, model_name, onnx_dir, opt=True): engine_name = self.engine_type.name.lower() - onnx_model_dir = os.path.join( - onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") - ) + directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + onnx_model_dir = os.path.join(onnx_dir, directory_name) os.makedirs(onnx_model_dir, exist_ok=True) return os.path.join(onnx_model_dir, "model.onnx") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py new file mode 100644 index 0000000000000..e6e8e9c040881 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -0,0 +1,172 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +class OrtCudaEngine(CudaSession): + def __init__(self, onnx_path, device_id: int = 0, enable_cuda_graph=False, disable_optimization=False): + self.onnx_path = onnx_path + self.provider = "CUDAExecutionProvider" + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + + session_options = ort.SessionOptions() + # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time. + if disable_optimization: + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + print("creating CUDA EP session for ", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + (self.provider, self.provider_options), + "CPUExecutionProvider", + ], + ) + print("created CUDA EP session for ", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtCudaEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. 
+ use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_CUDA, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height=512, + opt_image_width=512, + opt_batch_size=1, + force_engine_rebuild=False, + device_id=0, + disable_cuda_graph_models=None, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info("Exporting model: %s", onnx_path) + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(): + # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Run graph optimization and convert to mixed precision (computation in FP16) + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimized model: %s", onnx_opt_path) + model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=True) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True) + + use_cuda_graph = self.use_cuda_graph + if self.use_cuda_graph and disable_cuda_graph_models and model_name in disable_cuda_graph_models: + use_cuda_graph = False + + engine = OrtCudaEngine(onnx_opt_path, device_id=device_id, enable_cuda_graph=use_cuda_graph) + logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options) + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 2c4b8e8a1639e..a298e515c1454 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ 
b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py
@@ -13,6 +13,7 @@
 import onnx
 from optimize_pipeline import has_external_data
+from packaging import version
 
 from onnxruntime.transformers.fusion_options import FusionOptions
 from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
@@ -71,10 +72,6 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep
             use_gpu=True,
         )
 
-        if keep_outputs is None and self.model_type == "clip":
-            # remove the pooler_output, and only keep the first output.
-            keep_outputs = ["text_embeddings"]
-
         if keep_outputs:
             m.prune_graph(outputs=keep_outputs)
 
@@ -84,8 +81,16 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep
             keep_io_types=keep_io_types,
         )
 
+        use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
+
         # Note that ORT < 1.16 could not save model larger than 2GB.
-        if float16 or (self.model_type != "unet"):
+        # This step is optional since it has no impact on inference latency.
+        # The optimized model is not portable: it can only run on the same execution provider (CUDA EP in this case).
+        # When the model has been optimized by onnxruntime, we can disable optimization in SessionOptions
+        # to save session creation time. Another benefit is that the final graph can be inspected during development.
+        from onnxruntime import __version__ as ort_version
+
+        if version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format:
             m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format)
 
         m.get_operator_statistics()
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
index a053c9d5d0835..87443c990450b 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -30,6 +30,7 @@
 from diffusion_models import PipelineInfo, get_tokenizer
 from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler
 from engine_builder import EngineType
+from engine_builder_ort_cuda import OrtCudaEngineBuilder
 from engine_builder_ort_trt import OrtTensorrtEngineBuilder
 from engine_builder_tensorrt import TensorrtEngineBuilder
@@ -135,6 +136,8 @@ def __init__(
             self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph)
         elif engine_type == EngineType.ORT_TRT:
             self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph)
+        elif engine_type == EngineType.ORT_CUDA:
+            self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph)
         else:
             raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt
deleted file mode 100644
index 2a3caf4c2392b..0000000000000
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt
+++ /dev/null
@@ -1,14 +0,0 @@
--r requirements.txt
-onnxruntime-gpu>=1.16
-py3nvml>=0.2.7
-
-# cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu.
-cuda-python==11.8.0 -# For windows, cuda-python need the following -pywin32; platform_system == "Windows" - -nvtx - -# To export onnx, please install PyTorch 2.10 like -# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 -# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt new file mode 100644 index 0000000000000..5f908c4f5ff39 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. +onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==11.8.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu118 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt new file mode 100644 index 0000000000000..e4e765831c1b3 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. +# onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==12.1.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for 12.1 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. 
+# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt deleted file mode 100644 index 5b59c64ab7470..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ /dev/null @@ -1,2 +0,0 @@ --r requirements-cuda.txt -tensorrt>=8.6.1 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index d4e6c9fa07695..9386a941fb323 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -5,6 +5,7 @@ accelerate onnx>=1.13.0 coloredlogs packaging +# Use newer version of protobuf might cause crash protobuf==3.20.3 psutil sympy From c686c018ab1f9a905a768bd25e065b566c9f2db4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 4 Oct 2023 22:25:48 +0000 Subject: [PATCH 2/3] update logging --- .../transformers/models/stable_diffusion/demo_txt2img.py | 4 ---- .../models/stable_diffusion/engine_builder_ort_cuda.py | 4 ++-- .../models/stable_diffusion/engine_builder_ort_trt.py | 4 ++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index 0eeab19e8328a..f6e00063a6391 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -20,8 +20,6 @@ # limitations under the License. 
# -------------------------------------------------------------------------- -import logging - import coloredlogs from cuda import cudart from demo_utils import init_pipeline, parse_arguments, repeat_prompt @@ -29,8 +27,6 @@ from engine_builder import EngineType, get_engine_type from pipeline_txt2img import Txt2ImgPipeline -logger = logging.getLogger(__name__) - if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index e6e8e9c040881..936d04e8a1c43 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -29,7 +29,7 @@ def __init__(self, onnx_path, device_id: int = 0, enable_cuda_graph=False, disab if disable_optimization: session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - print("creating CUDA EP session for ", onnx_path) + logger.info("creating CUDA EP session for %s", onnx_path) ort_session = ort.InferenceSession( onnx_path, session_options, @@ -38,7 +38,7 @@ def __init__(self, onnx_path, device_id: int = 0, enable_cuda_graph=False, disab "CPUExecutionProvider", ], ) - print("created CUDA EP session for ", onnx_path) + logger.info("created CUDA EP session for %s", onnx_path) device = torch.device("cuda", device_id) super().__init__(ort_session, device, enable_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index 253cdcc45bf2e..a6bbd4ee7eeb7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -32,7 +32,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - print("creating TRT EP session for ", onnx_path) + logger.info("creating TRT EP session for %s", onnx_path) ort_session = ort.InferenceSession( onnx_path, session_options, @@ -40,7 +40,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works ("TensorrtExecutionProvider", self.ort_trt_provider_options), ], ) - print("created TRT EP session for ", onnx_path) + logger.info("created TRT EP session for %s", onnx_path) device = torch.device("cuda", device_id) super().__init__(ort_session, device, enable_cuda_graph) From 90ee7ffc4e7b17996f162aeca0de31b39b82bb8b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 4 Oct 2023 22:41:56 +0000 Subject: [PATCH 3/3] fix code scan warning --- .../transformers/models/stable_diffusion/ort_optimizer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index a298e515c1454..24570a6ef62da 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,7 +12,6 @@ from pathlib import Path import onnx -from optimize_pipeline import has_external_data from 
packaging import version from onnxruntime.transformers.fusion_options import FusionOptions @@ -60,8 +59,6 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False - use_external_data_format = has_external_data(input_fp32_onnx_path) - m = optimize_model( input_fp32_onnx_path, model_type=self.model_type,
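
Reviewer note (not part of the patch series): once patch 1/3 is applied, the new CUDA EP path can be smoke-tested directly from the demo, for example with `python demo_txt2img.py --engine ORT_CUDA "an astronaut riding a horse on mars"`. The `--engine ORT_CUDA` choice, the opset default selection, and `--force_engine_build` come from this series; the positional prompt and any other flags follow the existing demo CLI and are shown here only as an assumed illustration.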
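For readers unfamiliar with the ONNX Runtime CUDA execution provider, the sketch below illustrates the session-creation pattern that the new `OrtCudaEngine` wraps. It is a minimal standalone example using the public onnxruntime API, not code from this patch; the model path is hypothetical, and the exact provider-option dictionary returned by `CudaSession.get_cuda_provider_options` may differ.

```python
# Minimal sketch (assumptions noted in comments), mirroring what OrtCudaEngine does:
# create an InferenceSession on the CUDA execution provider, optionally with
# CUDA graph capture, and skip ORT-level graph optimization for a model that
# was already optimized offline by optimize_ort.
import onnxruntime as ort

onnx_path = "engine_dir/unet.ort_cuda/model.onnx"  # hypothetical path for illustration

provider_options = {
    "device_id": 0,
    # CUDA graph capture only pays off for static input shapes, which is why
    # the demo disables it for some models (clip2, unetxl).
    "enable_cuda_graph": "1",
}

session_options = ort.SessionOptions()
# The offline-optimized model does not need ORT graph optimization again, so
# disabling it shortens session creation time.
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

session = ort.InferenceSession(
    onnx_path,
    session_options,
    providers=[("CUDAExecutionProvider", provider_options), "CPUExecutionProvider"],
)
```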