diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 3d00c9cd6bf59..80249fb71003f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -92,13 +92,17 @@ python3 demo_txt2img_xl.py "Self-portrait oil painting, a beautiful cyborg with python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic" ``` +#### Generate an image with SDXL Turbo model guided by a text prompt +``` +python3 demo_txt2img_xl.py --version xl-turbo --height 512 --width 512 --denoising-steps 1 --scheduler UniPC "little cute gremlin sitting on a bed, cinematic" +``` + #### Generate an image with a text prompt using a control net ``` python3 demo_txt2img.py "Stormtrooper's lecture in beautiful lecture hall" --controlnet-type depth --controlnet-scale 1.0 python3 demo_txt2img_xl.py "young Mona Lisa" --controlnet-type canny --controlnet-scale 0.5 --scheduler UniPC --disable-refiner ``` - ## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 646e3518fa053..f5f168a8c7d58 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -54,8 +54,12 @@ def load_pipelines(args, batch_size): # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). - min_image_size = 832 if args.engine != "ORT_CUDA" else 512 - max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 + if args.version == "xl-turbo": + min_image_size = 512 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 + else: + min_image_size = 832 if args.engine != "ORT_CUDA" else 512 + max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 # No VAE decoder in base when it outputs latent instead of image. base_info = PipelineInfo( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index f0c83fc507ae4..5f2774632937a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -61,7 +61,7 @@ def parse_arguments(is_xl: bool, parser): parser.add_argument( "--version", type=str, - default=supported_versions[-1] if is_xl else "1.5", + default="xl-1.0" if is_xl else "1.5", choices=supported_versions, help="Version of Stable Diffusion" + (" XL." if is_xl else "."), ) @@ -244,6 +244,20 @@ def parse_arguments(is_xl: bool, parser): args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 if is_xl: + if args.version == "xl-turbo": + if args.guidance > 1.0: + print("[I] Use --guidance=1.0 for sdxl-turbo.") + args.guidance = 1.0 + if args.lcm: + print("[I] sdxl-turbo cannot use with LCM.") + args.lcm = False + if args.denoising_steps > 8: + print("[I] Use --denoising_steps=4 (no more than 8) for sdxl-turbo.") + args.denoising_steps = 4 + if not args.disable_refiner: + print("[I] sdxl-turbo cannot use with SDXL refiner.") + args.disable_refiner = True + if args.lcm and args.scheduler != "LCM": print("[I] Use --scheduler=LCM for base since LCM is used.") args.scheduler = "LCM" @@ -628,12 +642,12 @@ def process_controlnet_arguments(args): assert isinstance(args.controlnet_type, list) assert isinstance(args.controlnet_scale, list) assert isinstance(args.controlnet_image, list) - if args.version not in ["1.5", "xl-1.0"]: - raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5 or XL.") + if args.version not in ["1.5", "xl-1.0", "xl-turbo"]: + raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.") - is_xl = args.version == "xl-1.0" + is_xl = "xl" in args.version if is_xl and len(args.controlnet_type) > 1: - raise ValueError("This demo only support one ControlNet for Stable Diffusion XL.") + raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.") if len(args.controlnet_image) != 0 and len(args.controlnet_image) != len(args.controlnet_scale): raise ValueError( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index c09aff2f514c6..6715eb1c135d2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -120,17 +120,23 @@ def is_inpaint(self) -> bool: def is_xl(self) -> bool: return "xl" in self.version + def is_xl_turbo(self) -> bool: + return self.version == "xl-turbo" + def is_xl_base(self) -> bool: - return self.is_xl() and not self._is_refiner + return self.version == "xl-1.0" and not self._is_refiner + + def is_xl_base_or_turbo(self) -> bool: + return self.is_xl_base() or self.is_xl_turbo() def is_xl_refiner(self) -> bool: - return self.is_xl() and self._is_refiner + return self.version == "xl-1.0" and self._is_refiner def use_safetensors(self) -> bool: return self.is_xl() def stages(self) -> List[str]: - if self.is_xl_base(): + if self.is_xl_base_or_turbo(): return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else []) if self.is_xl_refiner(): @@ -153,7 +159,7 @@ def custom_unet(self) -> Optional[str]: @staticmethod def supported_versions(is_xl: bool): - return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] + return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] def name(self) -> str: if self.version == "1.4": @@ -185,6 +191,8 @@ def name(self) -> str: return "stabilityai/stable-diffusion-xl-refiner-1.0" else: return "stabilityai/stable-diffusion-xl-base-1.0" + elif self.version == "xl-turbo": + return "stabilityai/sdxl-turbo" raise ValueError(f"Incorrect version {self.version}") @@ -197,13 +205,13 @@ def clip_embedding_dim(self): return 768 elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): return 1024 - elif self.version in ("xl-1.0") and self.is_xl_base(): + elif self.is_xl_base_or_turbo(): return 768 else: raise ValueError(f"Invalid version {self.version}") def clipwithproj_embedding_dim(self): - if self.version in ("xl-1.0"): + if self.is_xl_base_or_turbo(): return 1280 else: raise ValueError(f"Invalid version {self.version}") @@ -213,9 +221,9 @@ def unet_embedding_dim(self): return 768 elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): return 1024 - elif self.version in ("xl-1.0") and self.is_xl_base(): + elif self.is_xl_base_or_turbo(): return 2048 - elif self.version in ("xl-1.0") and self.is_xl_refiner(): + elif self.version == "xl-1.0" and self.is_xl_refiner(): return 1280 else: raise ValueError(f"Invalid version {self.version}") @@ -227,7 +235,7 @@ def max_image_size(self): return self._max_image_size def default_image_size(self): - if self.is_xl(): + if self.version == "xl-1.0": return 1024 if self.version in ("2.0", "2.1"): return 768 @@ -235,7 +243,7 @@ def default_image_size(self): @staticmethod def supported_controlnet(version="1.5"): - if version == "xl-1.0": + if version in ("xl-1.0", "xl-turbo"): return { "canny": "diffusers/controlnet-canny-sdxl-1.0", "depth": "diffusers/controlnet-depth-sdxl-1.0", diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index d3387ab6db1bd..fa0035494217b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -40,7 +40,7 @@ def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): pipeline_info (PipelineInfo): Version and Type of stable diffusion pipeline. """ - assert pipeline_info.is_xl_base() + assert pipeline_info.is_xl_base_or_turbo() super().__init__(pipeline_info, *args, **kwargs)