diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 1ec1ca3ba0c83..54af8844d0c6c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -83,6 +83,9 @@ For example: If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. +#### Generate an image with SDXL LCM guided by a text prompt +```python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"``` + ## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index 4636f139d4613..b3056cc47c647 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_txt2img import Txt2ImgPipeline @@ -104,17 +104,25 @@ def run_inference(warmup=False): if not args.disable_cuda_graph: # inference once to get cuda graph - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Warming up ..") for _ in range(args.num_warmup_runs): - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Running StableDiffusion pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _image, _latency = run_inference(warmup=False) + images, perf_data = run_inference(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() + metadata = get_metadata(args, False) + metadata.update(pipeline.metadata()) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + pipeline.save_images(images, prompt, negative_prompt, metadata) + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 4f9ecf6cbb152..7ff1794a68f8c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_img2img_xl import Img2ImgXLPipeline @@ -54,7 +54,11 @@ def load_pipelines(args, batch_size): # No VAE decoder in base when it outputs latent instead of image. base_info = PipelineInfo( - args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size + args.version, + use_vae=args.disable_refiner, + min_image_size=min_image_size, + max_image_size=max_image_size, + use_lcm=args.lcm, ) # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to @@ -118,7 +122,7 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): - images, time_base = base.run( + images, base_perf = base.run( prompt, negative_prompt, image_height, @@ -130,24 +134,31 @@ def run_base_and_refiner(warmup=False): return_type="latent" if refiner else "image", ) if refiner is None: - return images, time_base + return images, base_perf # Use same seed in base and refiner. seed = base.get_current_seed() - images, time_refiner = refiner.run( + images, refiner_perf = refiner.run( prompt, negative_prompt, images, image_height, image_width, warmup=warmup, - denoising_steps=args.denoising_steps, - guidance=args.guidance, + denoising_steps=args.refiner_steps, + strength=args.strength, + guidance=args.refiner_guidance, seed=seed, ) - return images, time_base + time_refiner + perf_data = None + if base_perf and refiner_perf: + perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]} + perf_data.update({"base." + key: val for key, val in base_perf.items()}) + perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()}) + + return images, perf_data if not args.disable_cuda_graph: # inference once to get cuda graph @@ -164,13 +175,24 @@ def run_base_and_refiner(warmup=False): print("[I] Running StableDiffusion XL pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _, latency = run_base_and_refiner(warmup=False) + images, perf_data = run_base_and_refiner(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) - print("|------------|--------------|") + if refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", perf_data["latency"])) + print("|------------|--------------|") + + metadata = get_metadata(args, True) + metadata.update({"base." + key: val for key, val in base.metadata().items()}) + if refiner: + metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()}) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + (refiner or base).save_images(images, prompt, negative_prompt, metadata) def run_demo(args): @@ -189,6 +211,8 @@ def run_dynamic_shape_demo(args): """Run demo of generating images with different settings with ORT CUDA provider.""" args.engine = "ORT_CUDA" args.disable_cuda_graph = True + if args.lcm: + args.disable_refiner = True base, refiner = load_pipelines(args, 1) prompts = [ @@ -198,22 +222,31 @@ def run_dynamic_shape_demo(args): "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + "An astronaut riding a rainbow unicorn, cinematic, dramatic", + "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm", ] - # batch size, height, width, scheduler, steps, prompt, seed + # refiner, batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength configs = [ - (1, 832, 1216, "UniPC", 8, prompts[0], None), - (1, 1024, 1024, "DDIM", 24, prompts[1], None), - (1, 1216, 832, "UniPC", 16, prompts[2], None), - (1, 1344, 768, "DDIM", 24, prompts[3], None), - (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712), - (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906), + (1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3), + (1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3), + (1, 1216, 832, "UniPC", 16, prompts[2], None, 5.0, "UniPC", 10, 0.3), + (1, 1344, 768, "DDIM", 24, prompts[3], None, 5.0, "UniPC", 20, 0.3), + (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3), + (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3), ] + # In testing LCM, refiner is disabled so the settings of refiner is not used. + if args.lcm: + configs = [ + (1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3), + (1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3), + ] + # Warm up each combination of (batch size, height, width) once before serving. args.prompt = ["warm up"] args.num_warmup_runs = 1 - for batch_size, height, width, _, _, _, _ in configs: + for batch_size, height, width, _, _, _, _, _, _, _, _ in configs: args.batch_size = batch_size args.height = height args.width = width @@ -223,7 +256,19 @@ def run_dynamic_shape_demo(args): # Run pipeline on a list of prompts. args.num_warmup_runs = 0 - for batch_size, height, width, scheduler, steps, example_prompt, seed in configs: + for ( + batch_size, + height, + width, + scheduler, + steps, + example_prompt, + seed, + guidance, + refiner_scheduler, + refiner_steps, + strength, + ) in configs: args.prompt = [example_prompt] args.batch_size = batch_size args.height = height @@ -231,12 +276,13 @@ def run_dynamic_shape_demo(args): args.scheduler = scheduler args.denoising_steps = steps args.seed = seed + args.guidance = guidance + args.refiner_scheduler = refiner_scheduler + args.refiner_steps = refiner_steps + args.strength = strength base.set_scheduler(scheduler) if refiner: - refiner.set_scheduler(scheduler) - print( - f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}" - ) + refiner.set_scheduler(refiner_scheduler) prompt, negative_prompt = repeat_prompt(args) run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 39ee273a3130d..70b4f34fdd988 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -21,6 +21,7 @@ # -------------------------------------------------------------------------- import argparse +from typing import Any, Dict import torch from diffusion_models import PipelineInfo @@ -68,8 +69,8 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], - help="Scheduler for diffusion process", + choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process" + " of base" if is_xl else "", ) parser.add_argument( @@ -105,6 +106,42 @@ def parse_arguments(is_xl: bool, description: str): help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", ) + if is_xl: + parser.add_argument( + "--lcm", + action="store_true", + help="Use fine-tuned latent consistency model to replace the UNet in base.", + ) + + parser.add_argument( + "--refiner-scheduler", + type=str, + default="DDIM", + choices=["DDIM", "UniPC"], + help="Scheduler for diffusion process of refiner.", + ) + + parser.add_argument( + "--refiner-guidance", + type=float, + default=5.0, + help="Guidance scale used in refiner.", + ) + + parser.add_argument( + "--refiner-steps", + type=int, + default=30, + help="Number of denoising steps in refiner. Note that actual refiner steps is refiner_steps * strength.", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.3, + help="A value between 0 and 1. The higher the value less the final image similar to the seed image.", + ) + # ONNX export parser.add_argument( "--onnx-opset", @@ -190,11 +227,52 @@ def parse_arguments(is_xl: bool, description: str): if args.onnx_opset is None: args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + if is_xl: + if args.lcm: + if args.guidance > 1.0: + print("[I] Use --guidance=1.0 for base since LCM is used.") + args.guidance = 1.0 + if args.scheduler != "LCM": + print("[I] Use --scheduler=LCM for base since LCM is used.") + args.scheduler = "LCM" + if args.denoising_steps > 16: + print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.") + args.denoising_steps = 8 + assert args.strength > 0.0 and args.strength < 1.0 + print(args) return args +def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: + metadata = { + "args.prompt": args.prompt, + "args.negative_prompt": args.negative_prompt, + "args.batch_size": args.batch_size, + "height": args.height, + "width": args.width, + "cuda_graph": not args.disable_cuda_graph, + "vae_slicing": args.enable_vae_slicing, + "engine": args.engine, + } + + if is_xl and not args.disable_refiner: + metadata["base.scheduler"] = args.scheduler + metadata["base.denoising_steps"] = args.denoising_steps + metadata["base.guidance"] = args.guidance + metadata["refiner.strength"] = args.strength + metadata["refiner.scheduler"] = args.refiner_scheduler + metadata["refiner.denoising_steps"] = args.refiner_steps + metadata["refiner.guidance"] = args.refiner_guidance + else: + metadata["scheduler"] = args.scheduler + metadata["denoising_steps"] = args.denoising_steps + metadata["guidance"] = args.guidance + + return metadata + + def repeat_prompt(args): if not isinstance(args.prompt, list): raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") @@ -223,7 +301,7 @@ def init_pipeline( # Initialize demo pipeline = pipeline_class( pipeline_info, - scheduler=args.scheduler, + scheduler=args.refiner_scheduler if pipeline_info.is_xl_refiner() else args.scheduler, output_dir=output_dir, hf_token=args.hf_token, verbose=False, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 514205d3b8945..8206bee753859 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -91,6 +91,7 @@ def __init__( min_image_size=256, max_image_size=1024, use_fp16_vae=True, + use_lcm=False, ): self.version = version self._is_inpaint = is_inpaint @@ -99,7 +100,9 @@ def __init__( self._min_image_size = min_image_size self._max_image_size = max_image_size self._use_fp16_vae = use_fp16_vae + self._use_lcm = use_lcm if is_refiner: + assert not use_lcm assert self.is_xl() def is_inpaint(self) -> bool: @@ -136,6 +139,9 @@ def custom_fp16_vae(self) -> Optional[str]: # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None + def custom_unet(self) -> Optional[str]: + return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None + @staticmethod def supported_versions(is_xl: bool): return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] @@ -730,8 +736,22 @@ def __init__( self.unet_dim = unet_dim self.time_dim = time_dim + self.custom_unet = pipeline_info.custom_unet() + self.do_classifier_free_guidance = not (self.custom_unet and "lcm" in self.custom_unet) + self.batch_multiplier = 2 if self.do_classifier_free_guidance else 1 + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + + if self.custom_unet: + model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder) + if not os.path.exists(model_dir): + unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options) + unet.save_pretrained(model_dir) + else: + unet = UNet2DConditionModel.from_pretrained(model_dir, **options) + return unet.to(self.device) + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) def get_input_names(self): @@ -741,12 +761,20 @@ def get_output_names(self): return ["latent"] def get_dynamic_axes(self): + if self.do_classifier_free_guidance: + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "text_embeds": {0: "2B"}, - "time_ids": {0: "2B"}, + "sample": {0: "B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "B"}, + "latent": {0: "B", 2: "H", 3: "W"}, + "text_embeds": {0: "B"}, + "time_ids": {0: "B"}, } def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): @@ -763,49 +791,52 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + m = self.batch_multiplier return { "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + (m * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (m * batch_size, self.unet_dim, latent_height, latent_width), + (m * max_batch, self.unet_dim, max_latent_height, max_latent_width), ], "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), + (m * min_batch, self.text_maxlen, self.embedding_dim), + (m * batch_size, self.text_maxlen, self.embedding_dim), + (m * max_batch, self.text_maxlen, self.embedding_dim), ], - "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "text_embeds": [(m * min_batch, 1280), (m * batch_size, 1280), (m * max_batch, 1280)], "time_ids": [ - (2 * min_batch, self.time_dim), - (2 * batch_size, self.time_dim), - (2 * max_batch, self.time_dim), + (m * min_batch, self.time_dim), + (m * batch_size, self.time_dim), + (m * max_batch, self.time_dim), ], } def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + m = self.batch_multiplier return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "sample": (m * batch_size, self.unet_dim, latent_height, latent_width), "timestep": (1,), - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - "text_embeds": (2 * batch_size, 1280), - "time_ids": (2 * batch_size, self.time_dim), + "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (m * batch_size, 4, latent_height, latent_width), + "text_embeds": (m * batch_size, 1280), + "time_ids": (m * batch_size, self.time_dim), } def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) dtype = torch.float16 if self.fp16 else torch.float32 + m = self.batch_multiplier return ( torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device ), torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), { "added_cond_kwargs": { - "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), - "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), } }, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index 26c8450c57de9..6932c8056cf78 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -719,3 +719,228 @@ def configure(self): def __len__(self): return self.num_train_timesteps + + +# Modified from diffusers.schedulers.LCMScheduler +class LCMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.final_alpha_cumprod = self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.clip_sample_range = clip_sample_range + self.steps_offset = steps_offset + self.prediction_type = prediction_type + self.thresholding = thresholding + self.timestep_spacing = timestep_spacing + self.timestep_scaling = timestep_scaling + self.original_inference_steps = original_inference_steps + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + + self._step_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int, + strength: int = 1.0, + ): + assert num_inference_steps <= self.num_train_timesteps + + self.num_inference_steps = num_inference_steps + original_steps = self.original_inference_steps + + assert original_steps <= self.num_train_timesteps + assert num_inference_steps <= original_steps + + # LCM Timesteps Setting + # Currently, only linear spacing is supported. + c = self.num_train_timesteps // original_steps + # LCM Training Steps Schedule + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1 + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + # LCM Inference Steps Schedule + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long) + + self._step_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.clip_sample: + predicted_original_sample = predicted_original_sample.clamp(-self.clip_sample_range, self.clip_sample_range) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if self.step_index != self.num_inference_steps - 1: + noise = torch.randn( + model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator + ) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + return (prev_sample,) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index ace75bfbae7cb..fac72be346b3d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -77,6 +77,12 @@ def teardown(self): self.engines = {} def get_cached_model_name(self, model_name): + # TODO(tianleiwu): save custom model to a directory named by its original model. + if model_name == "unetxl" and self.pipeline_info.custom_unet(): + model_name = "lcm_" + model_name + + # TODO: When we support original VAE, we shall save custom VAE to another directory. + if self.pipeline_info.is_inpaint(): model_name += "_inpaint" return model_name @@ -93,6 +99,7 @@ def get_engine_path(self, engine_dir, model_name, profile_id): def load_models(self, framework_model_dir: str): # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + # TODO(tianleiwu): Test and remove it if this is not needed in Torch 2.1. if hasattr(torch.nn.functional, "scaled_dot_product_attention"): delattr(torch.nn.functional, "scaled_dot_product_attention") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py index faa3f8bfaabf1..31ede1ba901f2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -68,6 +68,7 @@ def _infer( image_height, image_width, denoising_steps=30, + strength=0.3, guidance=5.0, seed=None, warmup=False, @@ -79,7 +80,6 @@ def _infer( crops_coords_top_left = (0, 0) target_size = (image_height, image_width) - strength = 0.3 aesthetic_score = 6.0 negative_aesthetic_score = 2.5 @@ -155,12 +155,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Refiner Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "img2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, @@ -171,6 +171,7 @@ def run( image_width, denoising_steps=30, guidance=5.0, + strength=0.3, seed=None, warmup=False, return_type="image", @@ -213,6 +214,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, @@ -226,6 +228,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index e675c9a7b3bf5..a0b3c3a1c85b1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -23,12 +23,13 @@ import os import pathlib import random +from typing import Any, Dict, List import nvtx import torch from cuda import cudart from diffusion_models import PipelineInfo, get_tokenizer -from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler from engine_builder import EngineType from engine_builder_ort_cuda import OrtCudaEngineBuilder from engine_builder_ort_trt import OrtTensorrtEngineBuilder @@ -63,7 +64,7 @@ def __init__( max_batch_size (int): Maximum batch size for dynamic batch engine. scheduler (str): - The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM]. device (str): PyTorch device to run inference. Default: 'cuda' output_dir (str): @@ -162,9 +163,11 @@ def set_scheduler(self, scheduler: str): elif scheduler == "EulerA": self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) elif scheduler == "UniPC": - self.scheduler = UniPCMultistepScheduler(device=self.device) + self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts) + elif scheduler == "LCM": + self.scheduler = LCMScheduler(device=self.device, **sched_opts) else: - raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM") self.current_scheduler = scheduler self.denoising_steps = None @@ -238,6 +241,7 @@ def encode_prompt( pooled_outputs=False, output_hidden_states=False, force_zeros_for_empty_prompt=False, + do_classifier_free_guidance=True, ): if tokenizer is None: tokenizer = self.tokenizer @@ -265,41 +269,44 @@ def encode_prompt( if output_hidden_states: hidden_states = outputs["hidden_states"].clone() - # Note: negative prompt embedding is not needed for SD XL when guidance < 1 - - # For SD XL base, handle force_zeros_for_empty_prompt - is_empty_negative_prompt = all([not i for i in negative_prompt]) - if force_zeros_for_empty_prompt and is_empty_negative_prompt: - uncond_embeddings = torch.zeros_like(text_embeddings) - if output_hidden_states: - uncond_hidden_states = torch.zeros_like(hidden_states) - else: - # Tokenize negative prompt - uncond_input_ids = ( - tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + # Note: negative prompt embedding is not needed for SD XL when guidance <= 1 + if do_classifier_free_guidance: + # For SD XL base, handle force_zeros_for_empty_prompt + is_empty_negative_prompt = all([not i for i in negative_prompt]) + if force_zeros_for_empty_prompt and is_empty_negative_prompt: + uncond_embeddings = torch.zeros_like(text_embeddings) + if output_hidden_states: + uncond_hidden_states = torch.zeros_like(hidden_states) + else: + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) ) - .input_ids.type(torch.int32) - .to(self.device) - ) - outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) - uncond_embeddings = outputs["text_embeddings"] - if output_hidden_states: - uncond_hidden_states = outputs["hidden_states"] + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) if pooled_outputs: pooled_output = text_embeddings if output_hidden_states: - text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + else: + text_embeddings = hidden_states.to(dtype=torch.float16) cudart.cudaEventRecord(self.events["clip-stop"], 0) if self.nvtx_profile: @@ -321,7 +328,7 @@ def denoise_latent( guidance=7.5, add_kwargs=None, ): - assert guidance > 1.0, "Guidance has to be > 1.0" # TODO: remove this constraint + do_classifier_free_guidance = guidance > 1.0 cudart.cudaEventRecord(self.events["denoise-start"], 0) if not isinstance(timesteps, torch.Tensor): @@ -332,7 +339,7 @@ def denoise_latent( nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input( latent_model_input, step_offset + step_index, timestep @@ -366,11 +373,14 @@ def denoise_latent( nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) if type(self.scheduler) == UniPCMultistepScheduler: latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + elif type(self.scheduler) == LCMScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0] else: latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) @@ -406,38 +416,42 @@ def decode_latent(self, latents): nvtx.end_range(nvtx_vae) return images - def print_summary(self, tic, toc, batch_size, vae_enc=False): + def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: + throughput = batch_size / (toc - tic) + latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1] + latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + latency_vae_encoder = ( + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1] + if vae_enc + else None + ) + latency = (toc - tic) * 1000.0 + print("|------------|--------------|") print("| {:^10} | {:^12} |".format("Module", "Latency")) print("|------------|--------------|") if vae_enc: - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Enc", - cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "UNet x " + str(self.actual_steps), - cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] - ) - ) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder)) + print("| {:^10} | {:>9.2f} ms |".format("CLIP", latency_clip)) + print("| {:^10} | {:>9.2f} ms |".format("UNet x " + str(self.actual_steps), latency_unet)) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", latency)) print("|------------|--------------|") - print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + print(f"Throughput: {throughput:.2f} image/s") + + perf_data = { + "latency_clip": latency_clip, + "latency_unet": latency_unet, + "latency_vae": latency_vae, + "latency": latency, + "throughput": throughput, + } + if vae_enc: + perf_data["latency_vae_encoder"] = latency_vae_encoder + return perf_data @staticmethod def to_pil_image(images): @@ -449,26 +463,31 @@ def to_pil_image(images): return [Image.fromarray(images[i]) for i in range(images.shape[0])] - def save_images(self, images, pipeline, prompt): - image_name_prefix = ( - pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" - ) + def metadata(self) -> Dict[str, Any]: + return { + "actual_steps": self.actual_steps, + "seed": self.get_current_seed(), + "name": self.pipeline_info.name(), + "custom_vae": self.pipeline_info.custom_fp16_vae(), + "custom_unet": self.pipeline_info.custom_unet(), + } + def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]): images = self.to_pil_image(images) - random_session_id = str(random.randint(1000, 9999)) + session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): seed = str(self.get_current_seed()) - image_path = os.path.join( - self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + "-" + seed + ".png" - ) + prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20] + parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)] + image_path = os.path.join(self.output_dir, "-".join(parts) + ".png") print(f"Saving image {i+1} / {len(images)} to: {image_path}") from PIL import PngImagePlugin - metadata = PngImagePlugin.PngInfo() - metadata.add_text("prompt", prompt[i]) - metadata.add_text("batch_size", str(len(images))) - metadata.add_text("denoising_steps", str(self.denoising_steps)) - metadata.add_text("actual_steps", str(self.actual_steps)) - metadata.add_text("seed", seed) - image.save(image_path, "PNG", pnginfo=metadata) + info = PngImagePlugin.PngInfo() + for k, v in metadata.items(): + info.add_text(k, str(v)) + info.add_text("prompt", prompt[i]) + info.add_text("negative_prompt", negative_prompt[i]) + + image.save(image_path, "PNG", pnginfo=info) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index b9759b44e7635..87ce85af247a5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -84,11 +84,11 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "txt2img", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index 1b3be143e6ce7..8ed7e20e94c07 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -62,7 +62,7 @@ def _infer( return_type="image", ): assert len(prompt) == len(negative_prompt) - + do_classifier_free_guidance = guidance > 1.0 original_size = (image_height, image_width) crops_coords_top_left = (0, 0) target_size = (image_height, image_width) @@ -91,6 +91,7 @@ def _infer( tokenizer=self.tokenizer, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # CLIP text encoder 2 text_embeddings2, pooled_embeddings2 = self.encode_prompt( @@ -101,6 +102,7 @@ def _infer( pooled_outputs=True, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # Merged text embeddings @@ -111,9 +113,10 @@ def _infer( original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype ) add_time_ids = add_time_ids.repeat(batch_size, 1) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids.to(self.device)} # UNet denoiser latents = self.denoise_latent( @@ -133,13 +136,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Base Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - if return_type != "latent": - self.save_images(images, "txt2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index a00e25ddd983f..63fa8acfbcc95 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,8 +1,8 @@ -diffusers==0.19.3 -transformers==4.31.0 +diffusers==0.23.1 +transformers==4.35.1 numpy>=1.24.1 accelerate -onnx==1.14.0 +onnx==1.14.1 coloredlogs packaging # Use newer version of protobuf might cause crash