diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
index e6e8f3c7eb8e7..3ba167f518ed5 100755
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
@@ -1336,18 +1336,18 @@ def main():
             os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"
 
         result = run_optimum_ort(
-            sd_model,
-            args.pipeline,
-            provider,
-            args.batch_size,
-            not args.enable_safety_checker,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
+            model_name=sd_model,
+            directory=args.pipeline,
+            provider=provider,
+            batch_size=args.batch_size,
+            disable_safety_checker=not args.enable_safety_checker,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
         )
     elif args.engine == "onnxruntime":
         assert args.pipeline and os.path.isdir(
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py
new file mode 100644
index 0000000000000..909f40b362478
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py
@@ -0,0 +1,107 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+
+logger = logging.getLogger(__name__)
+
+
+class TorchEngineBuilder(EngineBuilder):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the PyTorch (torch.compile) Engine Builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                Device to run on.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference.
+        """
+        super().__init__(
+            EngineType.TORCH,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+        self.compile_config = {
+            "clip": {"mode": "reduce-overhead", "dynamic": False},
+            "clip2": {"mode": "reduce-overhead", "dynamic": False},
+            "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+            "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+        }
+
+        self.compile_config["vae"] = {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False}
+
+    def build_engines(
+        self,
+        framework_model_dir: str,
+    ):
+        import torch
+
+        self.torch_device = torch.device("cuda", torch.cuda.current_device())
+        self.load_models(framework_model_dir)
+
+        pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
+
+        built_engines = {}
+        for model_name, model_obj in self.models.items():
+            model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
+                model = model.to(device=self.torch_device, dtype=torch.float32)
+            else:
+                model = model.to(device=self.torch_device, dtype=torch.float16)
+
+            if model_name in self.compile_config:
+                compile_config = self.compile_config[model_name]
+                if model_name in ["unet", "unetxl"]:
+                    model.to(memory_format=torch.channels_last)
+                engine = torch.compile(model, **compile_config)
+                built_engines[model_name] = engine
+            else:
+                built_engines[model_name] = model
+
+        self.engines = built_engines
+
+    def run_engine(self, model_name, feed_dict):
+        if model_name in ["unet", "unetxl"]:
+            if "controlnet_images" in feed_dict:
+                return {"latent": self.engines[model_name](**feed_dict)}
+
+            if model_name == "unetxl":
+                added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
+                return {
+                    "latent": self.engines[model_name](
+                        feed_dict["sample"],
+                        feed_dict["timestep"],
+                        feed_dict["encoder_hidden_states"],
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                }
+
+            return {
+                "latent": self.engines[model_name](
+                    feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
+                )[0]
+            }
+
+        if model_name == "vae_encoder":
+            return {"latent": self.engines[model_name](feed_dict["images"])}
+
+        raise RuntimeError(f"Unexpected model name: {model_name}")
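
For reviewers, a minimal usage sketch of the new TorchEngineBuilder (not part of the patch). The version string "1.5", the model directory, and the tensor variables below are illustrative assumptions:

# Hypothetical driver; the concrete values here are assumptions, not taken from this change.
from diffusion_models import PipelineInfo
from engine_builder_torch import TorchEngineBuilder

pipeline_info = PipelineInfo("1.5")  # assumed: a version string accepted by PipelineInfo
builder = TorchEngineBuilder(pipeline_info, max_batch_size=4, device="cuda")
builder.build_engines("/path/to/framework_models")  # loads weights, then torch.compile per compile_config

# run_engine consumes the same feed_dict keys the builder reads above, e.g. one UNet
# denoising step (latents, timestep, text_embeddings are tensors prepared by the caller):
# outputs = builder.run_engine(
#     "unet",
#     {"sample": latents, "timestep": timestep, "encoder_hidden_states": text_embeddings},
# )
# latents = outputs["latent"]

Design note: mode="reduce-overhead" makes torch.compile capture CUDA graphs internally, which is why every entry in compile_config also pins dynamic=False; CUDA graph replay requires static shapes.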