From fe7b237a9780bb4aed7fd31bf3e4476f94584e71 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 20 Dec 2023 18:00:45 +0000
Subject: [PATCH] fix refiner

---
 .../stable_diffusion/benchmark_controlnet.py |  3 +--
 .../models/stable_diffusion/demo_utils.py    |  7 +++++++
 .../engine_builder_ort_cuda.py               |  6 +++---
 .../engine_builder_ort_trt.py                |  2 +-
 .../stable_diffusion/engine_builder_torch.py |  1 +
 .../models/stable_diffusion/ort_utils.py     | 25 -------------------------
 .../pipeline_stable_diffusion.py             | 19 ++++++++++---------
 7 files changed, 23 insertions(+), 40 deletions(-)
 delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py

diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py
index 86c6166472f3d..52c64fb7e8e0b 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py
@@ -139,7 +139,7 @@ def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_grap
     )
 
     engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT
-    onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
+    onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(
         work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type
     )
 
@@ -157,7 +157,6 @@ def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_grap
         engine_dir=engine_dir,
         framework_model_dir=framework_model_dir,
         onnx_dir=onnx_dir,
-        force_engine_rebuild=False,
         device_id=torch.cuda.current_device(),
     )
 
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
index 609853c80ae16..e1f18eafab788 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
@@ -380,6 +380,8 @@ def repeat_prompt(args):
 
 def initialize_pipeline(
     version="xl-turbo",
+    is_refiner: bool = False,
+    is_inpaint: bool = False,
     engine_type=EngineType.ORT_CUDA,
     work_dir: str = ".",
     engine_dir=None,
@@ -406,6 +408,8 @@ def initialize_pipeline(
 ):
     pipeline_info = PipelineInfo(
         version,
+        is_refiner=is_refiner,
+        is_inpaint=is_inpaint,
         use_vae=use_vae,
         min_image_size=min_image_size,
         max_image_size=max_image_size,
@@ -529,6 +533,8 @@ def load_pipelines(args, batch_size=None):
 
     params = {
         "version": args.version,
+        "is_refiner": False,
+        "is_inpaint": False,
         "engine_type": engine_type,
         "work_dir": args.work_dir,
         "engine_dir": args.engine_dir,
@@ -561,6 +567,7 @@ def load_pipelines(args, batch_size=None):
     refiner = None
     if "xl" in args.version and args.enable_refiner:
         params["version"] = "xl-1.0"  # Allow SDXL Turbo to use refiner.
+ params["is_refiner"] = True params["scheduler"] = args.refiner_scheduler params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0 params["lcm"] = False diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 30414776dab04..6ab4858f11f23 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -12,11 +12,11 @@ import torch from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType -from onnx import TensorProto -from ort_utils import CudaSession, OnnxModel from packaging import version import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession +from onnxruntime.transformers.onnx_model import OnnxModel logger = logging.getLogger(__name__) @@ -200,7 +200,7 @@ def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str): ) if model_name == "clip2": - model.change_graph_input_type(model.find_graph_input("input_ids"), TensorProto.INT32) + model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32) model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2")) elif model_name in ["unet", "unetxl"]: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index a0b9ae886f04e..a606b88c82245 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -11,10 +11,10 @@ from cuda import cudart from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType -from ort_utils import CudaSession from packaging import version import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py index 0c59d5485f1cb..84b9ec27b801e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py @@ -39,6 +39,7 @@ def __init__( use_cuda_graph=use_cuda_graph, ) + self.compile_config = {} if use_cuda_graph: self.compile_config = { "clip": {"mode": "reduce-overhead", "dynamic": False}, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py deleted file mode 100644 index f238b70389371..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# --------------------------------------------------------------------------
-
-import logging
-import os
-import sys
-
-logger = logging.getLogger(__name__)
-
-
-def add_transformers_dir_to_path():
-    sys.path.append(os.path.dirname(__file__))
-
-    transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
-    if transformers_dir not in sys.path:
-        sys.path.append(transformers_dir)
-
-
-add_transformers_dir_to_path()
-
-# Walkaround so that we can test local change without building new package
-from io_binding_helper import CudaSession  # noqa
-from onnx_model import OnnxModel  # noqa
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
index 85106f29167d4..104ce984bd401 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -395,19 +395,20 @@ def tokenize(prompt, output_hidden_states):
             # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
-        if pooled_outputs:
-            pooled_output = text_embeddings
-
-        if output_hidden_states:
-            if do_classifier_free_guidance:
-                text_embeddings = torch.cat([uncond_hidden_states, hidden_states])
-            else:
-                text_embeddings = hidden_states
+            if output_hidden_states:
+                hidden_states = torch.cat([uncond_hidden_states, hidden_states])
 
         self.stop_profile("clip")
 
         if pooled_outputs:
-            return text_embeddings.to(dtype=dtype), pooled_output.to(dtype=dtype)
+            # For text encoder in sdxl base
+            return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
+
+        if output_hidden_states:
+            # For text encoder 2 in sdxl base or refiner
+            return hidden_states.to(dtype=dtype)
+
+        # For text encoder in sd 1.5
         return text_embeddings.to(dtype=dtype)
 
     def denoise_latent(
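
Notes:

The refiner fix itself is the is_refiner plumbing in demo_utils.py: load_pipelines
now sets params["is_refiner"] = True for the xl-1.0 refiner stage, and
initialize_pipeline forwards is_refiner (and is_inpaint) into PipelineInfo, so the
refiner pipeline is no longer built from a base-model PipelineInfo. The
pipeline_stable_diffusion.py hunk then separates the CLIP outputs so hidden states
and final text embeddings are no longer folded into a single variable. A minimal
standalone sketch of the new return logic, with the tensors stubbed out as
arguments (the helper name is illustrative, not part of the patch):

    import torch

    def clip_returns(text_embeddings, hidden_states, pooled_outputs, output_hidden_states, dtype=torch.float16):
        if pooled_outputs:
            # Text encoder in SDXL base: hidden states plus the final text embeddings.
            return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
        if output_hidden_states:
            # Text encoder 2 in SDXL base or refiner: hidden states only.
            return hidden_states.to(dtype=dtype)
        # Text encoder in SD 1.5: the final text embeddings.
        return text_embeddings.to(dtype=dtype)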
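The deleted ort_utils.py existed only as a sys.path workaround so CudaSession and
OnnxModel could be imported from a local source tree. The engine builders now
import the same classes from the installed onnxruntime package. A quick import
check (assuming an onnxruntime build that ships the transformers tools):

    # The two names the engine builders rely on, resolved from the package itself.
    from onnxruntime.transformers.io_binding_helper import CudaSession
    from onnxruntime.transformers.onnx_model import OnnxModel

    print(CudaSession.__module__, OnnxModel.__module__)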
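The one-line engine_builder_torch.py change is a guard: compile_config was
previously assigned only under use_cuda_graph, so code that reads it later would
hit AttributeError whenever CUDA graphs were disabled. A condensed sketch of the
pattern (the class here is illustrative, not the real builder):

    class TorchBuilderSketch:
        def __init__(self, use_cuda_graph: bool = False):
            self.compile_config = {}  # always defined, even without CUDA graphs
            if use_cuda_graph:
                self.compile_config = {
                    "clip": {"mode": "reduce-overhead", "dynamic": False},
                }

    builder = TorchBuilderSketch(use_cuda_graph=False)
    assert builder.compile_config == {}  # safe to read either way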