diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 5927a469ca3e4..b10c10c87ee57 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -21,7 +21,7 @@ These optimizations are firstly carried out on CUDA EP. They may not work on oth | [demo_txt2img.py](./demo_txt2img.py) | Demo of text to image generation using Stable Diffusion models except XL. | | [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models exported from Huggingface diffusers or optimum | | [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | -| [benchmark_turbo.py](./benchmark_controlnet.py)| Benchmark latency of PyTorch or Stable-Fast with canny control net. | +| [benchmark_controlnet.py](./benchmark_controlnet.py)| Benchmark latency of canny control net. | ## Run demo with docker @@ -379,97 +379,6 @@ Common settings for below test results: | ------------------------------ | ---------------------- | ------ | ----- | ----- | ----------- | ----------- | | runwayml/stable-diffusion-v1-5 | TRUE | 512 | 512 | 50 | 5 | 1 | -#### Results of RTX 3060 (Windows 11) - -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDA | 1 | 4.8 | 4,117 | 4,625 | -| torch | 2.0.0+cu117 | default | 1 | 5.6 | 4,325 | 4,047 | -| torch | 1.13.1+cu117 | xformers | 1 | 6.0 | 9,124 | 9,130 | -| onnxruntime | 1.14.1 | CUDA | 4 | 17.7 | 6,659 | 6,659 | -| torch | 2.0.0+cu117 | default | 4 | 20.1 | 6,421 | 6,907 | -| torch | 1.13.1+cu117 | xformers | 4 | 21.6 | 10,407 | 10,409 | -| onnxruntime | 1.14.1 | CUDA | 8 | 33.5 | 6,663 | 6,663 | -| torch | 2.0.0+cu117 | default | 8 | 39.5 | 10,767 | 10,813 | -| torch | 1.13.1+cu117 | xformers | 8 | 41.1 | 10,825 | 9,255 | - - -#### Results of A100-SXM4-40GB (Ubuntu 20.04) -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDA | 1 | 1.1 | 6,883 | 7,395 | -| torch | 2.0.0+cu117 | default | 1 | 1.5 | 13,828 | 4,400 | -| torch | 2.0.0+cu117 | compile | 1 | 1.8 | 13,892 | 4,386 | -| onnxruntime | 1.14.1 | CUDA | 4 | 3.7 | 7,381 | 7,381 | -| torch | 2.0.0+cu117 | default | 4 | 3.9 | 31,278 | 6,870 | -| torch | 2.0.0+cu117 | compile | 4 | 3.4 | 31,364 | 6,880 | -| onnxruntime | 1.14.1 | CUDA | 8 | 6.9 | 7,411 | 7,411 | -| torch | 2.0.0+cu117 | default | 8 | 7.6 | 31,660 | 10,122 | -| torch | 2.0.0+cu117 | compile | 8 | 6.5 | 31,800 | 10,308 | -| onnxruntime | 1.14.1 | CUDA | 16 | 13.6 | 11,479 | 11,479 | -| torch | 2.0.0+cu117 | default | 16 | 14.8 | 32,306 | 16,520 | -| torch | 2.0.0+cu117 | compile | 16 | 12.6 | 32,636 | 16,898 | - -#### Results of A100-PCIE-80GB (Ubuntu 20.04) -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| tensorrt | 8.6.1 | default | 1 | 1.00 | 9,056 | 9,056 | -| onnxruntime | 1.16.0 nightly | tensorrt | 1 | 1.09 | 11,250 | 11,250 | -| onnxruntime | 1.16.0 nightly | tensorrt (cuda graph) | 1 | 0.96 | 11,382 | 11,382 | -| onnxruntime | 1.16.0 nightly | cuda | 1 | 1.11 | 4,760 | 5,144 | -| onnxruntime | 1.16.0 nightly | cuda (cuda graph) | 1 | 1.04 | 5,230 | 5,390 | -| tensorrt | 8.6.1 | default | 4 | 3.39 | 9,072 | 9,072 | -| onnxruntime | 1.16.0 nightly | tensorrt | 4 | 3.60 | 11,266 | 11,266 | -| onnxruntime | 1.16.0 nightly | tensorrt (cuda graph) | 4 | 3.43 | 11,428 | 11,428 | - -#### Results of V100-PCIE-16GB (Ubuntu 20.04) - -Results from Standard_NC6s_v3 Azure virtual machine: - -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDA | 1 | 2.7 | 12,646 | 7,152 | -| torch | 2.0.0+cu117 | compile | 1 | 3.2 | 13,317 | 3,909 | -| torch | 2.0.0+cu117 | default | 1 | 2.7 | 13,343 | 3,921 | -| torch | 1.13.1+cu117 | xformers | 1 | 3.5 | 14,979 | 10,449 | -| onnxruntime | 1.14.1 | CUDA | 4 | 8.4 | 7,114 | 7,114 | -| torch | 2.0.0+cu117 | compile | 4 | 8.0 | 13,897 | 6,821 | -| torch | 2.0.0+cu117 | default | 4 | 8.7 | 13,873 | 6,607 | -| torch | 1.13.1+cu117 | xformers | 4 | 9.1 | 12,969 | 8,421 | -| onnxruntime | 1.14.1 | CUDA | 8 | 15.9 | 7,120 | 7,120 | -| torch | 2.0.0+cu117 | compile | 8 | 15.5 | 14,669 | 10,355 | -| torch | 2.0.0+cu117 | default | 8 | 17.0 | 14,469 | 9,657 | -| torch | 1.13.1+cu117 | xformers | 8 | 17.4 | 15,593 | 9,133 | - -#### Results of T4 (Ubuntu 20.04) - -To make the result stable, we lock the frequency of T4 GPU like -`sudo nvidia-smi --lock-gpu-clocks=990` for fair comparison. See [nvidia blog](https://developer.nvidia.com/blog/advanced-api-performance-setstablepowerstate/) for more information. Note that performance might be slightly better without locking frequency. - -Results are from Standard_NC4as_T4_v3 Azure virtual machine: - -| engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | -| ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDA | 1 | 5.6 | 4,925 | 4,925 | -| onnxruntime | 1.15.1 | CUDA | 1 | 5.5 | 3,738 | 4,250 | -| onnxruntime | 1.15.1 (tensorrt 8.6.1) | Tensorrt | 1 | 4.8 | 10,710 | 10,710 | -| onnxruntime | 1.16.0 nightly | Tensorrt (cuda graph) | 1 | 4.7 | 11,746 | 10,746 | -| tensorrt | 8.6.1 | default | 1 | 5.0 | 8,530 | 8,530 | -| torch | 1.13.1+cu117 | xformers | 1 | 6.9 | 14,845 | 10,317 | -| torch | 2.0.0+cu117 | compile | 1 | 6.0 | 12,989 | 3,841 | -| torch | 2.0.0+cu117 | default | 1 | 6.4 | 12,987 | 3,841 | -| onnxruntime | 1.14.1 | CUDA | 4 | 23.0 | 6,977 | 6,977 | -| onnxruntime | 1.15.1 | CUDA | 4 | 22.6 | 6,298 | 6,298 | -| onnxruntime | 1.15.1 (tensorrt 8.6.1) | Tensorrt | 4 | 21.8 | 10,746 | 10,746 | -| tensorrt | 8.6.1 | default | 4 | 22.2 | 8,542 | 8,542 | -| torch | 1.13.1+cu117 | xformers | 4 | 25.8 | 12,819 | 8,269 | -| torch | 2.0.0+cu117 | compile | 4 | 22.2 | 14,637 | 6,583 | -| torch | 2.0.0+cu117 | default | 4 | 25.2 | 14,409 | 6,355 | -| onnxruntime | 1.14.1 | CUDA | 8 | 46.4 | 6,779 | 6,779 | -| torch | 1.13.1+cu117 | xformers | 8 | 51.4 | 14,827 | 9,001 | -| torch | 2.0.0+cu117 | compile | 8 | 46.5 | 12,595 | 10,171 | -| torch | 2.0.0+cu117 | default | 8 | 50.7 | 11,955 | 9,531 | - #### Results of MI250X, 1 GCD (Ubuntu 20.04) | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 1f1db914e274b..6c337af78e0a9 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -301,67 +301,95 @@ def run_ort( return result -def export_and_run_ort( - version: str, - provider: str, - batch_size: int, - disable_safety_checker: bool, - height: int, - width: int, - steps: int, - num_prompts: int, - batch_count: int, - start_memory, - memory_monitor_type, - enable_cuda_graph: bool, +def get_optimum_ort_pipeline( + model_name: str, + directory: str, + provider="CUDAExecutionProvider", + disable_safety_checker: bool = True, ): - assert provider == "CUDAExecutionProvider" + from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + + if directory is not None and os.path.exists(directory): + if "xl" in model_name: + pipeline = ORTStableDiffusionXLPipeline.from_pretrained( + directory, + provider=provider, + session_options=None, + use_io_binding=False, + ) + else: + pipeline = ORTStableDiffusionPipeline.from_pretrained( + directory, + provider=provider, + use_io_binding=False, + ) + elif "xl" in model_name: + pipeline = ORTStableDiffusionXLPipeline.from_pretrained( + model_name, + export=True, + provider=provider, + session_options=None, + use_io_binding=False, + ) + pipeline.save_pretrained(directory) + else: + pipeline = ORTStableDiffusionPipeline.from_pretrained( + model_name, + export=True, + provider=provider, + use_io_binding=False, + ) + pipeline.save_pretrained(directory) - from diffusers import DDIMScheduler - from diffusion_models import PipelineInfo - from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline + if disable_safety_checker: + pipeline.safety_checker = None + pipeline.feature_extractor = None - pipeline_info = PipelineInfo(version) - model_name = pipeline_info.name() + return pipeline - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( - model_name, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - enable_cuda_graph=enable_cuda_graph, - pipeline_info=pipeline_info, - ) - # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name) +def run_optimum_ort_pipeline( + pipe, + batch_size: int, + image_filename_prefix: str, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + memory_monitor_type, +): + from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline + + assert isinstance(pipe, (ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline)) - pipe = pipe.to("cuda", torch_dtype=torch.float16) + prompts = example_prompts() def warmup(): - pipe(["warm up"] * batch_size, image_height=height, image_width=width, num_inference_steps=steps) + pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) - # Run warm up, and measure GPU memory of two runs - # The first run has algo search so it might need more memory + # Run warm up, and measure GPU memory of two runs. + # The first run has algo search for cuDNN/MIOpen, so it might need more memory. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) - # An extra warm up run is needed for cuda graph warmup() - image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) - latency_list = [] - prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break for j in range(batch_count): inference_start = time.time() images = pipe( - [prompt] * batch_size, - negative_prompt=[negative_prompt] * batch_size, + prompt, + height, + width, num_inference_steps=steps, + negative_prompt=None, + guidance_scale=0.0, # 7.5 + num_images_per_prompt=batch_size, ).images inference_end = time.time() latency = inference_end - inference_start @@ -373,11 +401,8 @@ def warmup(): from onnxruntime import __version__ as ort_version return { - "model_name": model_name, - "engine": "onnxruntime", + "engine": "optimum_ort", "version": ort_version, - "provider": provider.replace("ExecutionProvider", ""), - "directory": pipe.engine_dir, "height": height, "width": width, "steps": steps, @@ -388,13 +413,13 @@ def warmup(): "median_latency": statistics.median(latency_list), "first_run_memory_MB": first_run_memory, "second_run_memory_MB": second_run_memory, - "disable_safety_checker": disable_safety_checker, - "enable_cuda_graph": enable_cuda_graph, } -def run_ort_trt( - version: str, +def run_optimum_ort( + model_name: str, + directory: str, + provider: str, batch_size: int, disable_safety_checker: bool, height: int, @@ -404,92 +429,36 @@ def run_ort_trt( batch_count: int, start_memory, memory_monitor_type, - max_batch_size: int, - enable_cuda_graph: bool, ): - from diffusers import DDIMScheduler - from diffusion_models import PipelineInfo - from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline - - pipeline_info = PipelineInfo(version) - model_name = pipeline_info.name() - - assert batch_size <= max_batch_size + load_start = time.time() + pipe = get_optimum_ort_pipeline(model_name, directory, provider, disable_safety_checker) + load_end = time.time() + print(f"Model loading took {load_end - load_start} seconds") - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( - model_name, - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - image_height=height, - image_width=width, - max_batch_size=max_batch_size, - onnx_opset=17, - enable_cuda_graph=enable_cuda_graph, - pipeline_info=pipeline_info, + image_filename_prefix = get_image_filename_prefix("optimum", model_name, batch_size, disable_safety_checker) + result = run_optimum_ort_pipeline( + pipe, + batch_size, + image_filename_prefix, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + memory_monitor_type, ) - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name, revision="fp16") - - pipe = pipe.to("cuda") - - def warmup(): - pipe(["warm up"] * batch_size, negative_prompt=["negative"] * batch_size, num_inference_steps=steps) - - # Run warm up, and measure GPU memory of two runs - # The first run has algo search so it might need more memory - first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) - second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) - - warmup() - - image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) - - latency_list = [] - prompts, negative_prompt = example_prompts() - for i, prompt in enumerate(prompts): - if i >= num_prompts: - break - for j in range(batch_count): - inference_start = time.time() - images = pipe( - [prompt] * batch_size, - negative_prompt=[negative_prompt] * batch_size, - num_inference_steps=steps, - ).images - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") - - from tensorrt import __version__ as trt_version - - from onnxruntime import __version__ as ort_version - - return { - "model_name": model_name, - "engine": "onnxruntime", - "version": ort_version, - "provider": f"tensorrt({trt_version})", - "directory": pipe.engine_dir, - "height": height, - "width": width, - "steps": steps, - "batch_size": batch_size, - "batch_count": batch_count, - "num_prompts": num_prompts, - "average_latency": sum(latency_list) / len(latency_list), - "median_latency": statistics.median(latency_list), - "first_run_memory_MB": first_run_memory, - "second_run_memory_MB": second_run_memory, - "disable_safety_checker": disable_safety_checker, - "enable_cuda_graph": enable_cuda_graph, - } + result.update( + { + "model_name": model_name, + "directory": directory, + "provider": provider.replace("ExecutionProvider", ""), + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": False, + } + ) + return result def run_ort_trt_static( @@ -523,17 +492,16 @@ def run_ort_trt_static( short_name = pipeline_info.short_name() from engine_builder import EngineType, get_engine_paths - from pipeline_txt2img import Txt2ImgPipeline + from pipeline_stable_diffusion import StableDiffusionPipeline engine_type = EngineType.ORT_TRT onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type) # Initialize pipeline - pipeline = Txt2ImgPipeline( + pipeline = StableDiffusionPipeline( pipeline_info, scheduler="DDIM", output_dir=output_dir, - hf_token=None, verbose=False, nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, @@ -551,7 +519,6 @@ def run_ort_trt_static( opt_image_height=height, opt_image_width=width, opt_batch_size=batch_size, - force_engine_rebuild=False, static_batch=True, static_image_shape=True, max_workspace_size=0, @@ -592,15 +559,11 @@ def warmup(): denoising_steps=steps, guidance=7.5, seed=123, - warmup=True, ) - images = pipeline.to_pil_image( - images - ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") @@ -664,7 +627,7 @@ def run_tensorrt_static( pipeline_info = PipelineInfo(version) from engine_builder import EngineType, get_engine_paths - from pipeline_txt2img import Txt2ImgPipeline + from pipeline_stable_diffusion import StableDiffusionPipeline engine_type = EngineType.TRT onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( @@ -672,11 +635,10 @@ def run_tensorrt_static( ) # Initialize pipeline - pipeline = Txt2ImgPipeline( + pipeline = StableDiffusionPipeline( pipeline_info, scheduler="DDIM", output_dir=output_dir, - hf_token=None, verbose=False, nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, @@ -693,16 +655,10 @@ def run_tensorrt_static( opt_batch_size=batch_size, opt_image_height=height, opt_image_width=width, - force_export=False, - force_optimize=False, - force_build=False, static_batch=True, static_shape=True, - enable_refit=False, - enable_preview=False, enable_all_tactics=False, timing_cache=timing_cache, - onnx_refit_dir=None, ) # activate engines @@ -744,15 +700,11 @@ def warmup(): denoising_steps=steps, guidance=7.5, seed=123, - warmup=True, ) - images = pipeline.to_pil_image( - images - ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") @@ -828,7 +780,6 @@ def init_pipeline(pipeline_class, pipeline_info): pipeline_info, scheduler="DDIM", output_dir=output_dir, - hf_token=None, verbose=False, nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, @@ -845,66 +796,39 @@ def init_pipeline(pipeline_class, pipeline_info): opt_batch_size=batch_size, opt_image_height=height, opt_image_width=width, - force_export=False, - force_optimize=False, - force_build=False, static_batch=True, static_shape=True, - enable_refit=False, - enable_preview=False, enable_all_tactics=False, timing_cache=timing_cache, - onnx_refit_dir=None, ) return pipeline - from pipeline_img2img_xl import Img2ImgXLPipeline - from pipeline_txt2img_xl import Txt2ImgXLPipeline - - base_pipeline_info = PipelineInfo(version) - demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + from pipeline_stable_diffusion import StableDiffusionPipeline - refiner_pipeline_info = PipelineInfo(version, is_refiner=True) - demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + pipeline_info = PipelineInfo(version) + pipeline = init_pipeline(StableDiffusionPipeline, pipeline_info) - max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) _, shared_device_memory = cudart.cudaMalloc(max_device_memory) - demo_base.backend.activate_engines(shared_device_memory) - demo_refiner.backend.activate_engines(shared_device_memory) + pipeline.backend.activate_engines(shared_device_memory) # Here we use static batch and image size, so the resource allocation only need done once. # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. - demo_base.load_resources(image_height, image_width, batch_size) - demo_refiner.load_resources(image_height, image_width, batch_size) + pipeline.load_resources(image_height, image_width, batch_size) - def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): - images, time_base = demo_base.run( + def run_sd_xl_inference(prompt, negative_prompt, seed=None): + return pipeline.run( prompt, negative_prompt, image_height, image_width, denoising_steps=steps, guidance=5.0, - warmup=warmup, seed=seed, - return_type="latent", ) - images, time_refiner = demo_refiner.run( - prompt, - negative_prompt, - images, - image_height, - image_width, - denoising_steps=steps, - guidance=5.0, - warmup=warmup, - seed=seed, - ) - return images, time_base + time_refiner - def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -913,7 +837,7 @@ def warmup(): warmup() - model_name = refiner_pipeline_info.name() + model_name = pipeline_info.name() image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) latency_list = [] @@ -926,23 +850,17 @@ def warmup(): # Use warmup mode here since non-warmup mode will save image to disk. if nvtx_profile: cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference( - [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True - ) + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) if nvtx_profile: cudart.cudaProfilerStop() - images = demo_refiner.to_pil_image( - images - ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") - demo_base.teardown() - demo_refiner.teardown() + pipeline.teardown() return { "model_name": model_name, @@ -979,97 +897,39 @@ def run_ort_trt_xl( nvtx_profile: bool = False, use_cuda_graph=True, ): - from cuda import cudart + from demo_utils import initialize_pipeline + from engine_builder import EngineType + + pipeline = initialize_pipeline( + version=version, + engine_type=EngineType.ORT_TRT, + work_dir=work_dir, + height=height, + width=width, + use_cuda_graph=use_cuda_graph, + max_batch_size=max_batch_size, + opt_batch_size=batch_size, + ) - # Validate image dimensions - image_height = height - image_width = width - if image_height % 8 != 0 or image_width % 8 != 0: - raise ValueError( - f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." - ) + from cuda import cudart assert batch_size <= max_batch_size - from engine_builder import EngineType, get_engine_paths - - def init_pipeline(pipeline_class, pipeline_info): - engine_type = EngineType.ORT_TRT - - onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths( - work_dir, pipeline_info, engine_type - ) - - # Initialize pipeline - pipeline = pipeline_class( - pipeline_info, - scheduler="DDIM", - output_dir=output_dir, - hf_token=None, - verbose=False, - nvtx_profile=nvtx_profile, - max_batch_size=max_batch_size, - use_cuda_graph=use_cuda_graph, - framework_model_dir=framework_model_dir, - engine_type=engine_type, - ) - - pipeline.backend.build_engines( - engine_dir, - framework_model_dir, - onnx_dir, - 17, - opt_image_height=height, - opt_image_width=width, - opt_batch_size=batch_size, - force_engine_rebuild=False, - static_batch=True, - static_image_shape=True, - max_workspace_size=0, - device_id=torch.cuda.current_device(), # TODO: might not work with CUDA_VISIBLE_DEVICES - ) - return pipeline - - from diffusion_models import PipelineInfo - from pipeline_img2img_xl import Img2ImgXLPipeline - from pipeline_txt2img_xl import Txt2ImgXLPipeline - - base_pipeline_info = PipelineInfo(version) - demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) - - refiner_pipeline_info = PipelineInfo(version, is_refiner=True) - demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) - - demo_base.load_resources(image_height, image_width, batch_size) - demo_refiner.load_resources(image_height, image_width, batch_size) + pipeline.load_resources(height, width, batch_size) - def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): - images, time_base = demo_base.run( + def run_sd_xl_inference(prompt, negative_prompt, seed=None): + return pipeline.run( prompt, negative_prompt, - image_height, - image_width, - denoising_steps=steps, - guidance=5.0, - warmup=warmup, - seed=seed, - return_type="latent", - ) - images, time_refiner = demo_refiner.run( - prompt, - negative_prompt, - images, - image_height, - image_width, + height, + width, denoising_steps=steps, guidance=5.0, - warmup=warmup, seed=seed, ) - return images, time_base + time_refiner def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -1078,7 +938,7 @@ def warmup(): warmup() - model_name = refiner_pipeline_info.name() + model_name = pipeline.pipeline_info.name() image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) latency_list = [] @@ -1091,25 +951,19 @@ def warmup(): # Use warmup mode here since non-warmup mode will save image to disk. if nvtx_profile: cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference( - [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True - ) + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) if nvtx_profile: cudart.cudaProfilerStop() - images = demo_refiner.to_pil_image( - images - ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") for k, image in enumerate(images): filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" image.save(filename) print("Image saved to", filename) - demo_base.teardown() - demo_refiner.teardown() + pipeline.teardown() from tensorrt import __version__ as trt_version @@ -1209,7 +1063,7 @@ def parse_arguments(): required=False, type=str, default="onnxruntime", - choices=["onnxruntime", "torch", "tensorrt"], + choices=["onnxruntime", "optimum", "torch", "tensorrt"], help="Engines to benchmark. Default is onnxruntime.", ) @@ -1423,26 +1277,6 @@ def main(): nvtx_profile=False, use_cuda_graph=args.enable_cuda_graph, ) - elif args.tuning: - print( - "Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.".format( - "static input shape" if args.enable_cuda_graph else "dynamic batch size" - ) - ) - result = run_ort_trt( - version=args.version, - batch_size=args.batch_size, - disable_safety_checker=not args.enable_safety_checker, - height=args.height, - width=args.width, - steps=args.steps, - num_prompts=args.num_prompts, - batch_count=args.batch_count, - start_memory=start_memory, - memory_monitor_type=memory_monitor_type, - max_batch_size=args.max_trt_batch_size, - enable_cuda_graph=args.enable_cuda_graph, - ) else: print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.") result = run_ort_trt_static( @@ -1461,15 +1295,13 @@ def main(): nvtx_profile=False, use_cuda_graph=args.enable_cuda_graph, ) + elif args.engine == "optimum" and provider == "CUDAExecutionProvider": + if "xl" in args.version: + os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1" - elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: - print( - "Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.".format( - "static" if args.enable_cuda_graph else "dynamic" - ) - ) - result = export_and_run_ort( - version=args.version, + result = run_optimum_ort( + model_name=sd_model, + directory=args.pipeline, provider=provider, batch_size=args.batch_size, disable_safety_checker=not args.enable_safety_checker, @@ -1480,7 +1312,6 @@ def main(): batch_count=args.batch_count, start_memory=start_memory, memory_monitor_type=memory_monitor_type, - enable_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py index 39b963313ea64..86c6166472f3d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark_controlnet.py @@ -9,6 +9,7 @@ from statistics import mean import torch +from demo_utils import PipelineInfo from diffusers import ( AutoencoderKL, ControlNetModel, @@ -16,6 +17,8 @@ EulerAncestralDiscreteScheduler, StableDiffusionXLControlNetPipeline, ) +from engine_builder import EngineType, get_engine_paths +from pipeline_stable_diffusion import StableDiffusionPipeline """ Benchmark script for SDXL-Turbo with control net for engines like PyTorch or Stable Fast. @@ -120,6 +123,112 @@ def load_pipeline(name, engine, use_control_net=False, use_nhwc=False, enable_cu return pipeline +def get_prompt(): + return "little cute gremlin wearing a jacket, cinematic, vivid colors, intricate masterpiece, golden ratio, highly detailed" + + +def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_graph=True, work_dir="."): + version = PipelineInfo.supported_models()[name] + guidance_scale = 0.0 + pipeline_info = PipelineInfo( + version, + use_vae=True, + use_fp16_vae=True, + do_classifier_free_guidance=(guidance_scale > 1.0), + controlnet=["canny"] if use_control_net else [], + ) + + engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type + ) + + pipeline = StableDiffusionPipeline( + pipeline_info, + scheduler="EulerA", + max_batch_size=32, + use_cuda_graph=enable_cuda_graph, + framework_model_dir=framework_model_dir, + output_dir=output_dir, + engine_type=engine_type, + ) + + pipeline.backend.build_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + force_engine_rebuild=False, + device_id=torch.cuda.current_device(), + ) + + return pipeline + + +def test_ort_cuda( + pipeline, + batch_size=1, + steps=4, + control_image=None, + warmup_runs=3, + test_runs=10, + seed=123, + verbose=False, + image_height=512, + image_width=512, +): + if batch_size > 4 and pipeline.pipeline_info.version == "xl-1.0": + pipeline.backend.enable_vae_slicing() + + pipeline.load_resources(image_height, image_width, batch_size) + + warmup_prompt = "warm up" + for _ in range(warmup_runs): + images, _ = pipeline.run( + [warmup_prompt] * batch_size, + [""] * batch_size, + image_height=image_height, + image_width=image_width, + denoising_steps=steps, + guidance=0.0, + seed=seed, + controlnet_images=[control_image], + controlnet_scales=torch.FloatTensor([0.5]), + output_type="image", + ) + assert len(images) == batch_size + + generator = torch.Generator(device="cuda") + generator.manual_seed(seed) + + prompt = get_prompt() + + latency_list = [] + images = None + for _ in range(test_runs): + torch.cuda.synchronize() + start_time = time.perf_counter() + images, _ = pipeline.run( + [prompt] * batch_size, + [""] * batch_size, + image_height=image_height, + image_width=image_width, + denoising_steps=steps, + guidance=0.0, + seed=seed, + controlnet_images=[control_image], + controlnet_scales=torch.FloatTensor([0.5]), + output_type="pil", + ) + torch.cuda.synchronize() + seconds = time.perf_counter() - start_time + latency_list.append(seconds) + + if verbose: + print(latency_list) + + return images, latency_list + + def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, test_runs=10, seed=123, verbose=False): control_net_args = {} if hasattr(pipeline, "controlnet"): @@ -130,33 +239,33 @@ def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, tes warmup_prompt = "warm up" for _ in range(warmup_runs): - image = pipeline( + images = pipeline( prompt=warmup_prompt, num_inference_steps=steps, num_images_per_prompt=batch_size, guidance_scale=0.0, **control_net_args, ).images - assert len(image) == batch_size + assert len(images) == batch_size generator = torch.Generator(device="cuda") generator.manual_seed(seed) - prompt = "little cute gremlin wearing a jacket, cinematic, vivid colors, intricate masterpiece, golden ratio, highly detailed" + prompt = get_prompt() latency_list = [] - image = None + images = None for _ in range(test_runs): torch.cuda.synchronize() start_time = time.perf_counter() - image = pipeline( + images = pipeline( prompt=prompt, num_inference_steps=steps, num_images_per_prompt=batch_size, guidance_scale=0.0, generator=generator, **control_net_args, - ).images[0] + ).images torch.cuda.synchronize() seconds = time.perf_counter() - start_time latency_list.append(seconds) @@ -164,7 +273,7 @@ def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, tes if verbose: print(latency_list) - return image, latency_list + return images, latency_list def arguments(): @@ -175,17 +284,25 @@ def arguments(): "--engine", type=str, default="torch", - choices=["torch", "stable_fast"], - help="Backend engine: torch or stable_fast", + choices=["torch", "stable_fast", "ort_cuda", "ort_trt"], + help="Backend engine: torch, stable_fast or ort_cuda", ) parser.add_argument( "--name", type=str, + choices=list(PipelineInfo.supported_models().keys()), default="stabilityai/sdxl-turbo", help="Stable diffusion model name. Default is stabilityai/sdxl-turbo", ) + parser.add_argument( + "--work-dir", + type=str, + default=".", + help="working directory for ort_cuda or ort_trt", + ) + parser.add_argument( "--use_control_net", action="store_true", @@ -239,21 +356,39 @@ def main(): args = arguments() with torch.no_grad(): - pipeline = load_pipeline( - args.name, - args.engine, - use_control_net=args.use_control_net, - use_nhwc=args.use_nhwc, - enable_cuda_graph=args.enable_cuda_graph, - ) + if args.engine == "ort_cuda": + pipeline = load_ort_cuda_pipeline( + args.name, + args.engine, + use_control_net=args.use_control_net, + enable_cuda_graph=args.enable_cuda_graph, + work_dir=args.work_dir, + ) + else: + pipeline = load_pipeline( + args.name, + args.engine, + use_control_net=args.use_control_net, + use_nhwc=args.use_nhwc, + enable_cuda_graph=args.enable_cuda_graph, + ) canny_image = get_canny_image() - if args.engine == "stable_fast": + if args.engine == "ort_cuda": + images, latency_list = test_ort_cuda( + pipeline, + args.batch_size, + args.steps, + control_image=canny_image, + warmup_runs=args.warmup_runs, + verbose=args.verbose, + ) + elif args.engine == "stable_fast": from sfast.utils.compute_precision import low_compute_precision with low_compute_precision(): - image, latency_list = test( + images, latency_list = test( pipeline, args.batch_size, args.steps, @@ -262,7 +397,7 @@ def main(): verbose=args.verbose, ) else: - image, latency_list = test( + images, latency_list = test( pipeline, args.batch_size, args.steps, @@ -272,8 +407,8 @@ def main(): ) # Save the first output image to inspect the result. - if image: - image.save( + if images: + images[0].save( f"{args.engine}_{args.name.replace('/', '_')}_{args.batch_size}_{args.steps}_c{int(args.use_control_net)}.png" ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index c18747d5c6518..40692701c28d6 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -26,15 +26,11 @@ add_controlnet_arguments, arg_parser, get_metadata, - init_pipeline, - max_batch, + load_pipelines, parse_arguments, process_controlnet_arguments, repeat_prompt, ) -from diffusion_models import PipelineInfo -from engine_builder import EngineType, get_engine_type -from pipeline_txt2img import Txt2ImgPipeline if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") @@ -45,83 +41,26 @@ controlnet_images, controlnet_scale = process_controlnet_arguments(args) - prompt, negative_prompt = repeat_prompt(args) - - image_height = args.height - image_width = args.width - - # Register TensorRT plugins - engine_type = get_engine_type(args.engine) - if engine_type == EngineType.TRT: - from trt_utilities import init_trt_plugins - - init_trt_plugins() - - max_batch_size = max_batch(args) + pipeline, refiner = load_pipelines(args) + assert refiner is None + prompt, negative_prompt = repeat_prompt(args) batch_size = len(prompt) - if batch_size > max_batch_size: - raise ValueError( - f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" - ) - - # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. - # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. - # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. - min_image_size = 512 if args.engine != "ORT_CUDA" else 256 - max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 - pipeline_info = PipelineInfo( - args.version, - min_image_size=min_image_size, - max_image_size=max_image_size, - do_classifier_free_guidance=(args.guidance > 1.0), - controlnet=args.controlnet_type, - lora_weights=args.lora_weights, - lora_scale=args.lora_scale, - ) - - # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to - # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. - # In this demo, we optimize batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for dynamic engine. - # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. - opt_batch_size = 1 if args.build_dynamic_batch else batch_size - opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height - opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width - - pipeline = init_pipeline( - Txt2ImgPipeline, - pipeline_info, - engine_type, - args, - max_batch_size, - opt_batch_size, - opt_image_height, - opt_image_width, - ) - - if engine_type == EngineType.TRT: - max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) - _, shared_device_memory = cudart.cudaMalloc(max_device_memory) - pipeline.backend.activate_engines(shared_device_memory) - - if engine_type == EngineType.ORT_CUDA and args.enable_vae_slicing: - pipeline.backend.enable_vae_slicing() - - pipeline.load_resources(image_height, image_width, batch_size) + pipeline.load_resources(args.height, args.width, batch_size) def run_inference(warmup=False): return pipeline.run( prompt, negative_prompt, - image_height, - image_width, - warmup=warmup, + args.height, + args.width, denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, controlnet_images=controlnet_images, controlnet_scales=controlnet_scale, - return_type="image", + show_latency=not warmup, + output_type="pil", ) if not args.disable_cuda_graph: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index b691f5115e6d3..19bbb45d77c93 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -26,107 +26,11 @@ add_controlnet_arguments, arg_parser, get_metadata, - init_pipeline, - max_batch, + load_pipelines, parse_arguments, process_controlnet_arguments, repeat_prompt, ) -from diffusion_models import PipelineInfo -from engine_builder import EngineType, get_engine_type -from pipeline_img2img_xl import Img2ImgXLPipeline -from pipeline_txt2img_xl import Txt2ImgXLPipeline - - -def load_pipelines(args, batch_size): - # Register TensorRT plugins - engine_type = get_engine_type(args.engine) - if engine_type == EngineType.TRT: - from trt_utilities import init_trt_plugins - - init_trt_plugins() - - max_batch_size = max_batch(args) - - if batch_size > max_batch_size: - raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") - - # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. - # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. - # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). - if args.version == "xl-turbo": - min_image_size = 512 - max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 - else: - min_image_size = 832 if args.engine != "ORT_CUDA" else 512 - max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 - - # No VAE decoder in base when it outputs latent instead of image. - base_info = PipelineInfo( - args.version, - use_vae=not args.enable_refiner, - min_image_size=min_image_size, - max_image_size=max_image_size, - use_lcm=args.lcm, - do_classifier_free_guidance=(args.guidance > 1.0), - controlnet=args.controlnet_type, - lora_weights=args.lora_weights, - lora_scale=args.lora_scale, - ) - - # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to - # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. - # In this demo, we optimize batch size 1 and image size 1024x1024 for SD XL dynamic engine. - # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. - opt_batch_size = 1 if args.build_dynamic_batch else batch_size - opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height - opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width - - base = init_pipeline( - Txt2ImgXLPipeline, - base_info, - engine_type, - args, - max_batch_size, - opt_batch_size, - opt_image_height, - opt_image_width, - ) - - refiner = None - if args.enable_refiner: - refiner_version = "xl-1.0" # Allow SDXL Turbo to use refiner. - refiner_info = PipelineInfo( - refiner_version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size - ) - refiner = init_pipeline( - Img2ImgXLPipeline, - refiner_info, - engine_type, - args, - max_batch_size, - opt_batch_size, - opt_image_height, - opt_image_width, - ) - - if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) - _, shared_device_memory = cudart.cudaMalloc(max_device_memory) - base.backend.activate_engines(shared_device_memory) - if refiner: - refiner.backend.activate_engines(shared_device_memory) - - if engine_type == EngineType.ORT_CUDA: - enable_vae_slicing = args.enable_vae_slicing - if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024): - print( - "Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024." - ) - enable_vae_slicing = True - if enable_vae_slicing: - (refiner or base).backend.enable_vae_slicing() - return base, refiner def run_pipelines( @@ -145,13 +49,13 @@ def run_base_and_refiner(warmup=False): negative_prompt, image_height, image_width, - warmup=warmup, denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, controlnet_images=controlnet_image, controlnet_scales=controlnet_scale, - return_type="latent" if refiner else "image", + show_latency=not warmup, + output_type="latent" if refiner else "pil", ) if refiner is None: return images, base_perf @@ -162,14 +66,14 @@ def run_base_and_refiner(warmup=False): images, refiner_perf = refiner.run( prompt, negative_prompt, - images, image_height, image_width, - warmup=warmup, denoising_steps=args.refiner_denoising_steps, + image=images, strength=args.strength, guidance=args.refiner_guidance, seed=seed, + show_latency=not warmup, ) perf_data = None @@ -309,6 +213,32 @@ def run_dynamic_shape_demo(args): refiner.teardown() +def run_turbo_demo(args): + """Run demo of generating images with test prompts with ORT CUDA provider.""" + args.engine = "ORT_CUDA" + args.disable_cuda_graph = True + base, refiner = load_pipelines(args, 1) + + from datasets import load_dataset + + dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts") + num_rows = dataset["test"].num_rows + batch_size = args.batch_size + num_batch = int(num_rows / batch_size) + args.batch_size = 1 + for i in range(num_batch): + args.prompt = [dataset["test"][i]["Prompt"] for i in range(i * batch_size, (i + 1) * batch_size)] + base.set_scheduler(args.scheduler) + if refiner: + refiner.set_scheduler(args.refiner_scheduler) + prompt, negative_prompt = repeat_prompt(args) + run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) + + base.teardown() + if refiner: + refiner.teardown() + + if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") @@ -318,6 +248,9 @@ def run_dynamic_shape_demo(args): no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0] if no_prompt: - run_dynamic_shape_demo(args) + if args.version == "xl-turbo": + run_turbo_demo(args) + else: + run_dynamic_shape_demo(args) else: run_demo(args) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index c0395b5e4642f..609853c80ae16 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -29,9 +29,11 @@ import cv2 import numpy as np import torch +from cuda import cudart from diffusion_models import PipelineInfo -from engine_builder import EngineType, get_engine_paths +from engine_builder import EngineType, get_engine_paths, get_engine_type from PIL import Image +from pipeline_stable_diffusion import StableDiffusionPipeline class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): @@ -40,7 +42,8 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte def arg_parser(description: str): return argparse.ArgumentParser( - description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter, add_help=False + description=description, + formatter_class=RawTextArgumentDefaultsHelpFormatter, ) @@ -65,8 +68,7 @@ def set_default_arguments(args): def parse_arguments(is_xl: bool, parser): - engines = ["ORT_CUDA", "ORT_TRT", "TRT"] - parser.add_argument("--help", action="store_true", help="show this help message and exit") + engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"] parser.add_argument( "-e", @@ -89,14 +91,14 @@ def parse_arguments(is_xl: bool, parser): ) parser.add_argument( - "-h", + "-y", "--height", type=int, default=None, help="Height of image to generate (must be multiple of 8).", ) parser.add_argument( - "-w", "--width", type=int, default=None, help="Height of image to generate (must be multiple of 8)." + "-x", "--width", type=int, default=None, help="Height of image to generate (must be multiple of 8)." ) parser.add_argument( @@ -115,6 +117,13 @@ def parse_arguments(is_xl: bool, parser): help="Root Directory to store torch or ONNX models, built engines and output images etc.", ) + parser.add_argument( + "-i", + "--engine-dir", + default=None, + help="Root Directory to store built engines or optimized ONNX models etc.", + ) + parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.") parser.add_argument( @@ -208,23 +217,8 @@ def parse_arguments(is_xl: bool, parser): choices=range(14, 18), help="Select ONNX opset version to target for exported models.", ) - parser.add_argument( - "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." - ) - parser.add_argument( - "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." - ) - - # Framework model ckpt - parser.add_argument( - "--framework-model-dir", - default="pytorch_model", - help="Directory for HF saved models. Default is pytorch_model.", - ) - parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") # Engine build options. - parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") parser.add_argument( "-db", "--build-dynamic-batch", @@ -252,34 +246,14 @@ def parse_arguments(is_xl: bool, parser): # TensorRT only options group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") - group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") - group.add_argument( - "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." - ) - group.add_argument( - "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." - ) group.add_argument( "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." ) args = parser.parse_args() - if args.help: - parser.print_help() - sys.exit() set_default_arguments(args) - if ( - args.engine in ["ORT_CUDA", "ORT_TRT"] - and (args.force_onnx_export or args.force_onnx_optimize) - and not args.force_engine_build - ): - raise ValueError( - "For ORT_CUDA or ORT_TRT, --force_onnx_export and --force_onnx_optimize are not supported. " - "Please use --force_engine_build instead." - ) - # Validate image dimensions if args.height % 64 != 0 or args.width % 64 != 0: raise ValueError( @@ -404,79 +378,217 @@ def repeat_prompt(args): return prompt, negative_prompt -def init_pipeline( - pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width +def initialize_pipeline( + version="xl-turbo", + engine_type=EngineType.ORT_CUDA, + work_dir: str = ".", + engine_dir=None, + onnx_opset: int = 17, + scheduler="EulerA", + height=512, + width=512, + nvtx_profile=False, + use_cuda_graph=True, + build_dynamic_batch=False, + build_dynamic_shape=False, + min_image_size: int = 512, + max_image_size: int = 1024, + max_batch_size: int = 16, + opt_batch_size: int = 1, + build_all_tactics=False, + do_classifier_free_guidance=False, + lcm=False, + controlnet=None, + lora_weights=None, + lora_scale=1.0, + use_fp16_vae=True, + use_vae=True, ): + pipeline_info = PipelineInfo( + version, + use_vae=use_vae, + min_image_size=min_image_size, + max_image_size=max_image_size, + use_fp16_vae=use_fp16_vae, + use_lcm=lcm, + do_classifier_free_guidance=do_classifier_free_guidance, + controlnet=controlnet, + lora_weights=lora_weights, + lora_scale=lora_scale, + ) + + input_engine_dir = engine_dir + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( - work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type + work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type ) - # Initialize demo - pipeline = pipeline_class( + pipeline = StableDiffusionPipeline( pipeline_info, - scheduler=args.refiner_scheduler if pipeline_info.is_xl_refiner() else args.scheduler, + scheduler=scheduler, output_dir=output_dir, - hf_token=args.hf_token, verbose=False, - nvtx_profile=args.nvtx_profile, + nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, - use_cuda_graph=not args.disable_cuda_graph, + use_cuda_graph=use_cuda_graph, framework_model_dir=framework_model_dir, engine_type=engine_type, ) + import_engine_dir = None + if input_engine_dir: + if not os.path.exists(input_engine_dir): + raise RuntimeError(f"--engine_dir directory does not exist: {input_engine_dir}") + + # Support importing from optimized diffusers onnx pipeline + if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(input_engine_dir, "model_index.json")): + import_engine_dir = input_engine_dir + else: + engine_dir = input_engine_dir + + opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height + opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width + if engine_type == EngineType.ORT_CUDA: - # Build CUDA EP engines and load pytorch modules pipeline.backend.build_engines( engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, - tmp_dir=os.path.join(args.work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"), - force_engine_rebuild=args.force_engine_build, + tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"), device_id=torch.cuda.current_device(), + import_engine_dir=import_engine_dir, ) elif engine_type == EngineType.ORT_TRT: - # Build TensorRT EP engines and load pytorch modules pipeline.backend.build_engines( engine_dir, framework_model_dir, onnx_dir, - args.onnx_opset, + onnx_opset, opt_image_height=opt_image_height, opt_image_width=opt_image_width, opt_batch_size=opt_batch_size, - force_engine_rebuild=args.force_engine_build, - static_batch=not args.build_dynamic_batch, - static_image_shape=not args.build_dynamic_shape, + static_batch=not build_dynamic_batch, + static_image_shape=not build_dynamic_shape, max_workspace_size=0, device_id=torch.cuda.current_device(), timing_cache=timing_cache, ) elif engine_type == EngineType.TRT: - # Load TensorRT engines and pytorch modules pipeline.backend.load_engines( engine_dir, framework_model_dir, onnx_dir, - args.onnx_opset, + onnx_opset, opt_batch_size=opt_batch_size, opt_image_height=opt_image_height, opt_image_width=opt_image_width, - force_export=args.force_onnx_export, - force_optimize=args.force_onnx_optimize, - force_build=args.force_engine_build, - static_batch=not args.build_dynamic_batch, - static_shape=not args.build_dynamic_shape, - enable_refit=args.build_enable_refit, - enable_preview=args.build_preview_features, - enable_all_tactics=args.build_all_tactics, + static_batch=not build_dynamic_batch, + static_shape=not build_dynamic_shape, + enable_all_tactics=build_all_tactics, timing_cache=timing_cache, - onnx_refit_dir=args.onnx_refit_dir, ) + elif engine_type == EngineType.TORCH: + pipeline.backend.build_engines(framework_model_dir) + else: + raise RuntimeError("invalid engine type") return pipeline +def load_pipelines(args, batch_size=None): + engine_type = get_engine_type(args.engine) + + # Register TensorRT plugins + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = max_batch(args) + + if batch_size is None: + assert isinstance(args.prompt, list) + batch_size = len(args.prompt) * args.batch_size + + if batch_size > max_batch_size: + raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") + + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). + if args.version == "xl-turbo": + min_image_size = 512 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 + elif args.version == "xl-1.0": + min_image_size = 832 if args.engine != "ORT_CUDA" else 512 + max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 + else: + # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. + min_image_size = 512 if args.engine != "ORT_CUDA" else 256 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 + + params = { + "version": args.version, + "engine_type": engine_type, + "work_dir": args.work_dir, + "engine_dir": args.engine_dir, + "onnx_opset": args.onnx_opset, + "scheduler": args.scheduler, + "height": args.height, + "width": args.width, + "nvtx_profile": args.nvtx_profile, + "use_cuda_graph": not args.disable_cuda_graph, + "build_dynamic_batch": args.build_dynamic_batch, + "build_dynamic_shape": args.build_dynamic_shape, + "min_image_size": min_image_size, + "max_image_size": max_image_size, + "max_batch_size": max_batch_size, + "opt_batch_size": 1 if args.build_dynamic_batch else batch_size, + "build_all_tactics": args.build_all_tactics, + "do_classifier_free_guidance": args.guidance > 1.0, + "controlnet": args.controlnet_type, + "lora_weights": args.lora_weights, + "lora_scale": args.lora_scale, + "use_fp16_vae": "xl" in args.version, + "use_vae": True, + } + + if "xl" in args.version: + params["lcm"] = args.lcm + params["use_vae"] = not args.enable_refiner + base = initialize_pipeline(**params) + + refiner = None + if "xl" in args.version and args.enable_refiner: + params["version"] = "xl-1.0" # Allow SDXL Turbo to use refiner. + params["scheduler"] = args.refiner_scheduler + params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0 + params["lcm"] = False + params["controlnet"] = None + params["lora_weights"] = None + params["use_vae"] = True + params["use_fp16_vae"] = True + refiner = initialize_pipeline(**params) + + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + if refiner: + refiner.backend.activate_engines(shared_device_memory) + + if engine_type == EngineType.ORT_CUDA: + enable_vae_slicing = args.enable_vae_slicing + if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024): + print( + "Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024." + ) + enable_vae_slicing = True + if enable_vae_slicing: + (refiner or base).backend.enable_vae_slicing() + return base, refiner + + def get_depth_image(image): """ Create depth map for SDXL depth control net. @@ -542,7 +654,7 @@ def add_controlnet_arguments(parser, is_xl: bool = False): """ Add control net related arguments. """ - group = parser.add_argument_group("Options for ControlNet (only supports SD 1.5 or XL).") + group = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).") group.add_argument( "-ci", @@ -622,7 +734,7 @@ def process_controlnet_arguments(args): if len(args.controlnet_type) == 0: return None, None - if args.version not in ["1.5", "xl-1.0", "xl-turbo"]: + if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]: raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.") is_xl = "xl" in args.version diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 9f3c5a8c938c6..10af22e44d3a5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -87,7 +87,7 @@ def __init__( version: str, is_inpaint: bool = False, is_refiner: bool = False, - use_vae=False, + use_vae=True, # TODO: this has couple with output type of pipeline min_image_size=256, max_image_size=1024, use_fp16_vae=True, @@ -161,6 +161,23 @@ def custom_unet(self) -> Optional[str]: def supported_versions(is_xl: bool): return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base", "sd-turbo"] + @staticmethod + def supported_models(): + return { + "CompVis/stable-diffusion-v1-4": "1.4", + "runwayml/stable-diffusion-v1-5": "1.5", + "stabilityai/stable-diffusion-2-base": "2.0-base", + "stabilityai/stable-diffusion-2": "2.0", + "stabilityai/stable-diffusion-2-1": "2.1", + "stabilityai/stable-diffusion-2-1-base": "2.1", + "stabilityai/stable-diffusion-xl-base-1.0": "xl-1.0", + "stabilityai/stable-diffusion-xl-refiner-1.0": "xl-1.0", + "stabilityai/sdxl-turbo": "xl-turbo", + "stabilityai/sd-turbo": "sd-turbo", + # "runwayml/stable-diffusion-inpainting": "1.5", + # "stabilityai/stable-diffusion-2-inpainting": "2.0", + } + def name(self) -> str: if self.version == "1.4": if self.is_inpaint(): @@ -329,7 +346,7 @@ def get_ort_optimizer(self): def get_model(self): return self.model - def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder=None, model_name=None, **kwargs): + def from_pretrained(self, model_class, framework_model_dir, subfolder=None, model_name=None, **kwargs): if model_name is None: model_name = self.pipeline_info.name() @@ -343,7 +360,6 @@ def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder= model_name, subfolder=subfolder, use_safetensors=self.pipeline_info.use_safetensors(), - use_auth_token=hf_token, **kwargs, ).to(self.device) model.save_pretrained(model_dir) @@ -353,7 +369,7 @@ def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder= model = model_class.from_pretrained(model_dir).to(self.device) return model - def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): + def load_model(self, framework_model_dir: str, subfolder: str): pass def get_input_names(self) -> List[str]: @@ -405,8 +421,7 @@ def get_shape_dict(self, batch_size, image_height, image_width): def fp32_input_output_names(self) -> List[str]: """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model. - This is a list of input or output names that are kept as float32 during converting. - For the first version, we will use same data type as TensorRT. + This is a list of input or output names that are kept as float32 in optimized model. """ return [] @@ -519,7 +534,7 @@ def get_output_names(self): return ["text_embeddings"] def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + return {"input_ids": {0: "B", 1: "S"}, "text_embeddings": {0: "B", 1: "S"}} def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): self.check_dims(batch_size, image_height, image_width) @@ -581,7 +596,7 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, onnx.helper.make_tensor_value_info( graph_output_name, graph.output[0].type.tensor_type.elem_type, - ["B", self.text_maxlen, self.embedding_dim], + ["B", "S", self.embedding_dim], ) ) @@ -660,8 +675,8 @@ def optimize_trt(self, input_onnx_path, optimized_onnx_path): else: onnx.save(onnx_opt_graph, optimized_onnx_path) - def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder"): - return self.from_pretrained(CLIPTextModel, framework_model_dir, hf_token, subfolder) + def load_model(self, framework_model_dir, subfolder="text_encoder"): + return self.from_pretrained(CLIPTextModel, framework_model_dir, subfolder) class CLIPWithProj(CLIP): @@ -682,8 +697,8 @@ def __init__( clip_skip=clip_skip, ) - def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder_2"): - return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, hf_token, subfolder) + def load_model(self, framework_model_dir, subfolder="text_encoder_2"): + return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, subfolder) def get_shape_dict(self, batch_size, image_height, image_width): self.check_dims(batch_size, image_height, image_width) @@ -816,10 +831,10 @@ def __init__( self.unet_dim = unet_dim self.controlnet = pipeline_info.controlnet_name() - def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + def load_model(self, framework_model_dir, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} - model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options) if self.controlnet: controlnet_list = [] @@ -827,7 +842,6 @@ def load_model(self, framework_model_dir, hf_token, subfolder="unet"): controlnet = self.from_pretrained( ControlNetModel, framework_model_dir, - hf_token, subfolder=None, model_name=name, torch_dtype=torch.float16, @@ -929,10 +943,8 @@ def get_sample_input(self, batch_size, image_height, image_width): dtype = torch.float16 if self.fp16 else torch.float32 m = self.get_batch_multiplier() output = ( - torch.randn( - m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device), + torch.tensor([1.0], dtype=dtype, device=self.device), torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ) @@ -946,9 +958,6 @@ def get_sample_input(self, batch_size, image_height, image_width): ) return output - def fp32_input_output_names(self) -> List[str]: - return ["sample", "timestep"] - class UNetXL(BaseModel): def __init__( @@ -977,7 +986,7 @@ def __init__( self.custom_unet = pipeline_info.custom_unet() self.controlnet = pipeline_info.controlnet_name() - def load_model(self, framework_model_dir, hf_token, subfolder="unet", always_download_fp16=True): + def load_model(self, framework_model_dir, subfolder="unet", always_download_fp16=True): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 or always_download_fp16 else {} if self.custom_unet: @@ -989,7 +998,7 @@ def load_model(self, framework_model_dir, hf_token, subfolder="unet", always_dow unet = UNet2DConditionModel.from_pretrained(model_dir, **options) model = unet.to(self.device) else: - model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options) if always_download_fp16 and not self.fp16: model = model.to(torch.float32) @@ -1107,9 +1116,9 @@ def get_sample_input(self, batch_size, image_height, image_width): if not self.controlnet: return ( torch.randn( - m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.tensor([1.0], dtype=dtype, device=self.device), torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), { "added_cond_kwargs": { @@ -1122,9 +1131,9 @@ def get_sample_input(self, batch_size, image_height, image_width): # sample, timestep, encoder_hidden_states, text_embeds, time_ids, controlnet_images, controlnet_scales, return ( torch.randn( - m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.tensor([1.0], dtype=dtype, device=self.device), torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), @@ -1134,9 +1143,6 @@ def get_sample_input(self, batch_size, image_height, image_width): torch.randn(len(self.controlnet), dtype=dtype, device=self.device), ) - def fp32_input_output_names(self) -> List[str]: - return ["sample", "timestep"] - # VAE Decoder class VAE(BaseModel): @@ -1160,7 +1166,7 @@ def __init__( # For SD XL, need custom trained fp16 model to speed up, and avoid overflow at the same time. self.custom_fp16_vae = custom_fp16_vae - def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): + def load_model(self, framework_model_dir, subfolder: str = "vae_decoder"): model_name = self.custom_fp16_vae or self.pipeline_info.name() model_dir = os.path.join(framework_model_dir, model_name, subfolder) @@ -1172,7 +1178,6 @@ def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfol self.pipeline_info.name(), subfolder="vae", use_safetensors=self.pipeline_info.use_safetensors(), - use_auth_token=hf_token, ).to(self.device) vae.save_pretrained(model_dir) else: @@ -1225,13 +1230,14 @@ def get_shape_dict(self, batch_size, image_height, image_width): def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + dtype = torch.float16 if self.fp16 else torch.float32 + return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device),) def fp32_input_output_names(self) -> List[str]: - return [] if self.fp16 else ["latent", "images"] + return [] -def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): +def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, subfolder="tokenizer"): tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) if not os.path.exists(tokenizer_dir): @@ -1239,7 +1245,6 @@ def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, su pipeline_info.name(), subfolder=subfolder, use_safetensors=pipeline_info.is_xl(), - use_auth_token=hf_token, ) model.save_pretrained(tokenizer_dir) else: @@ -1266,8 +1271,8 @@ def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): max_batch_size=max_batch_size, ) - def load_model(self, framework_model_dir, hf_token, subfolder="vae_encoder"): - vae = self.from_pretrained(AutoencoderKL, framework_model_dir, hf_token, subfolder) + def load_model(self, framework_model_dir, subfolder="vae_encoder"): + vae = self.from_pretrained(AutoencoderKL, framework_model_dir, subfolder) return TorchVAEEncoder(vae) def get_input_names(self): diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index ffa986f53304c..6bd00a854a97f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -34,7 +34,6 @@ def __init__( pipeline_info: PipelineInfo, device="cuda", max_batch_size=16, - hf_token=None, use_cuda_graph=False, ): """ @@ -47,21 +46,18 @@ def __init__( device to run engine max_batch_size (int): Maximum batch size for dynamic batch engine. - hf_token (str): - HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. use_cuda_graph (bool): Use CUDA graph to capture engine execution and then launch inference """ self.engine_type = engine_type self.pipeline_info = pipeline_info self.max_batch_size = max_batch_size - self.hf_token = hf_token self.use_cuda_graph = use_cuda_graph self.device = torch.device(device) self.torch_device = torch.device(device, torch.cuda.current_device()) self.stages = pipeline_info.stages() - self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() + self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae() self.models = {} @@ -87,24 +83,36 @@ def teardown(self): del engine self.engines = {} + def get_diffusers_module_name(self, model_name): + name_mapping = { + "clip": "text_encoder", + "clip2": "text_encoder_2", + "unet": "unet", + "unetxl": "unet", + "vae": "vae_decoder", + } + return name_mapping[model_name] if model_name in name_mapping else model_name + def get_cached_model_name(self, model_name): + model_name = self.get_diffusers_module_name(model_name) + is_unet = model_name == "unet" hash_source = [] - if model_name in ["clip", "clip2", "unet", "unetxl"] and self.pipeline_info.lora_weights: + if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights: if self.pipeline_info.lora_weights in [ "latent-consistency/lcm-lora-sdxl", "latent-consistency/lcm-lora-sdv1-5", ]: - if model_name in ["unet", "unetxl"]: - model_name = model_name + "_lcm-lora" + if is_unet: + model_name = "unet_lcm-lora" else: model_name = model_name + "_lora" hash_source.append(self.pipeline_info.lora_weights) # TODO(tianleiwu): save custom model to a directory named by its original model. - if model_name == "unetxl" and self.pipeline_info.custom_unet(): + if is_unet and self.pipeline_info.custom_unet(): model_name = model_name + "_lcm" - if model_name in ["unet", "unetxl"] and self.pipeline_info.controlnet: + if model_name in ["unet"] and self.pipeline_info.controlnet: model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet) if hash_source: @@ -118,8 +126,9 @@ def get_cached_model_name(self, model_name): def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True): engine_name = self.engine_type.name.lower() - # TODO: Need not add engine name for ORT_CUDA - directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + suffix + if engine_name != "ort_cuda" and not suffix: + suffix = f".{engine_name}" if opt else "" + directory_name = self.get_cached_model_name(model_name) + suffix onnx_model_dir = os.path.join(root_dir, directory_name) if create: os.makedirs(onnx_model_dir, exist_ok=True) @@ -160,14 +169,14 @@ def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir model = pipeline.unet pipeline.unet = None else: - model = model_obj.load_model(framework_model_dir, self.hf_token) + model = model_obj.load_model(framework_model_dir) return model.to(self.torch_device) def load_models(self, framework_model_dir: str): # For TRT or ORT_TRT, we will export fp16 torch model for UNet. # For ORT_CUDA, we export fp32 model first, then optimize to fp16. - export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] if "clip" in self.stages: self.models["clip"] = CLIP( @@ -192,7 +201,7 @@ def load_models(self, framework_model_dir: str): self.pipeline_info, None, # not loaded yet device=self.torch_device, - fp16=export_fp16_unet, + fp16=export_fp16, max_batch_size=self.max_batch_size, unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), ) @@ -202,7 +211,7 @@ def load_models(self, framework_model_dir: str): self.pipeline_info, None, # not loaded yet device=self.torch_device, - fp16=export_fp16_unet, + fp16=export_fp16, max_batch_size=self.max_batch_size, unet_dim=4, time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6), @@ -215,13 +224,17 @@ def load_models(self, framework_model_dir: str): None, # not loaded yet device=self.torch_device, max_batch_size=self.max_batch_size, + fp16=(export_fp16 and self.custom_fp16_vae is not None), custom_fp16_vae=self.custom_fp16_vae, ) if self.vae_torch_fallback: - self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir, self.hf_token) + self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir) def load_resources(self, image_height, image_width, batch_size): + if self.engine_type == EngineType.TORCH: + return + # Allocate buffers for I/O bindings for model_name, obj in self.models.items(): if model_name == "vae" and self.vae_torch_fallback: @@ -232,13 +245,22 @@ def load_resources(self, image_height, image_width, batch_size): ) def _vae_decode(self, latents): - if self.vae_torch_fallback: + if self.engine_type == EngineType.TORCH: + if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast + latents = latents.to(dtype=torch.float32) + images = self.engines["vae"](latents)["sample"] + else: + images = self.engines["vae"](latents)["sample"] + elif self.vae_torch_fallback: if not self.custom_fp16_vae: latents = latents.to(dtype=torch.float32) self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) images = self.torch_models["vae"](latents)["sample"] else: - images = self.run_engine("vae", {"latent": latents})["images"] + if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast + images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"] + else: + images = self.run_engine("vae", {"latent": latents})["images"] return images diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 2ac9a45577676..30414776dab04 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -6,13 +6,15 @@ import gc import logging import os -import shutil from typing import List, Optional +import onnx import torch from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType -from ort_utils import CudaSession +from onnx import TensorProto +from ort_utils import CudaSession, OnnxModel +from packaging import version import onnxruntime as ort @@ -83,7 +85,6 @@ def __init__( self, pipeline_info: PipelineInfo, max_batch_size=16, - hf_token=None, device="cuda", use_cuda_graph=False, ): @@ -95,8 +96,6 @@ def __init__( Version and Type of pipeline. max_batch_size (int): Maximum batch size for dynamic batch engine. - hf_token (str): - HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. device (str): device to run. use_cuda_graph (bool): @@ -106,7 +105,6 @@ def __init__( EngineType.ORT_CUDA, pipeline_info, max_batch_size=max_batch_size, - hf_token=hf_token, device=device, use_cuda_graph=use_cuda_graph, ) @@ -153,6 +151,65 @@ def configure_xl(self, onnx_opset_version: int): use_cuda_graph=self.use_cuda_graph, ) + def optimized_onnx_path(self, engine_dir, model_name): + suffix = "" if self.model_config[model_name].fp16 else ".fp32" + return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) + + def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str): + """Import optimized onnx models for diffusers from Olive or optimize_pipeline tools. + + Args: + diffusers_onnx_dir (str): optimized onnx directory of Olive + engine_dir (str): the directory to store imported onnx + """ + if version.parse(ort.__version__) < version.parse("1.17.0"): + print("Skip importing since onnxruntime-gpu version < 1.17.0.") + return + + for model_name, model_obj in self.models.items(): + onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name) + if not os.path.exists(onnx_import_path): + print(f"{onnx_import_path} not existed. Skip importing.") + continue + + onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name) + if os.path.exists(onnx_opt_path): + print(f"{onnx_opt_path} existed. Skip importing.") + continue + + if model_name == "vae" and self.pipeline_info.is_xl(): + print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.") + continue + + model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True)) + + if model_name in ["clip", "clip2"]: + hidden_states_per_layer = [] + for output in model.graph().output: + if output.name.startswith("hidden_states."): + hidden_states_per_layer.append(output.name) + if hidden_states_per_layer: + kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip] + model.rename_graph_output(kept_hidden_states, "hidden_states") + + model.rename_graph_output( + "last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings" + ) + model.prune_graph( + ["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"] + ) + + if model_name == "clip2": + model.change_graph_input_type(model.find_graph_input("input_ids"), TensorProto.INT32) + + model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2")) + elif model_name in ["unet", "unetxl"]: + model.rename_graph_output("out_sample", "latent") + model.save_model_to_file(onnx_opt_path, use_external_data_format=True) + + del model + continue + def build_engines( self, engine_dir: str, @@ -160,21 +217,13 @@ def build_engines( onnx_dir: str, tmp_dir: Optional[str] = None, onnx_opset_version: int = 17, - force_engine_rebuild: bool = False, device_id: int = 0, - save_fp32_intermediate_model=False, + save_fp32_intermediate_model: bool = False, + import_engine_dir: Optional[str] = None, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - if not os.path.isdir(engine_dir): os.makedirs(engine_dir) @@ -188,6 +237,13 @@ def build_engines( if model_name not in self.model_config: self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph) + # Import Engine + if import_engine_dir: + if self.pipeline_info.is_xl(): + self.import_diffusers_engine(import_engine_dir, engine_dir) + else: + print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}") + # Load lora only when we need export text encoder or UNet to ONNX. load_lora = False if self.pipeline_info.lora_weights: @@ -195,9 +251,7 @@ def build_engines( if model_name not in ["clip", "clip2", "unet", "unetxl"]: continue onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) - - suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) + onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name) if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): load_lora = True @@ -212,8 +266,7 @@ def build_engines( continue onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) - suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) + onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name) if not os.path.exists(onnx_opt_path): if not os.path.exists(onnx_path): print("----") @@ -280,7 +333,7 @@ def build_engines( fp32_op_list=self.model_config[model_name].force_fp32_ops, optimize_by_ort=optimize_by_ort, optimize_by_fusion=not use_fp32_intermediate, - tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp16", create=False), + tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False), ) else: logger.info("Found cached optimized model: %s", onnx_opt_path) @@ -291,9 +344,7 @@ def build_engines( if model_name == "vae" and self.vae_torch_fallback: continue - suffix = ".fp16" if self.model_config[model_name].fp16 else ".fp32" - onnx_opt_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix) - + onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name) use_cuda_graph = self.model_config[model_name].use_cuda_graph engine = OrtCudaEngine( @@ -308,7 +359,5 @@ def build_engines( self.engines = built_engines - return built_engines - def run_engine(self, model_name, feed_dict): return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index 8c637007b840d..a0b9ae886f04e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -6,7 +6,6 @@ import gc import logging import os -import shutil import torch from cuda import cudart @@ -110,7 +109,6 @@ def __init__( self, pipeline_info: PipelineInfo, max_batch_size=16, - hf_token=None, device="cuda", use_cuda_graph=False, ): @@ -122,8 +120,6 @@ def __init__( Version and Type of pipeline. max_batch_size (int): Maximum batch size for dynamic batch engine. - hf_token (str): - HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. device (str): device to run. use_cuda_graph (bool): @@ -133,7 +129,6 @@ def __init__( EngineType.ORT_TRT, pipeline_info, max_batch_size=max_batch_size, - hf_token=hf_token, device=device, use_cuda_graph=use_cuda_graph, ) @@ -165,7 +160,6 @@ def build_engines( opt_image_height, opt_image_width, opt_batch_size=1, - force_engine_rebuild=False, static_batch=False, static_image_shape=True, max_workspace_size=0, @@ -175,14 +169,6 @@ def build_engines( self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - if not os.path.isdir(engine_dir): os.makedirs(engine_dir) @@ -298,7 +284,5 @@ def build_engines( self.engines = built_engines - return built_engines - def run_engine(self, model_name, feed_dict): return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py index bac1a8bb8140d..438145fc2c57a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -26,8 +26,6 @@ from collections import OrderedDict import numpy as np -import onnx -import onnx_graphsurgeon as gs import tensorrt as trt import torch from cuda import cudart @@ -43,7 +41,6 @@ network_from_onnx_path, save_engine, ) -from trt_utilities import TRT_LOGGER # Map of numpy dtype -> torch dtype numpy_to_torch_dtype_dict = { @@ -83,115 +80,11 @@ def __del__(self): del self.buffers del self.tensors - def refit(self, onnx_path, onnx_refit_path): - def convert_int64(arr): - if len(arr.shape) == 0: - return np.int32(arr) - return arr - - def add_to_map(refit_dict, name, values): - if name in refit_dict: - assert refit_dict[name] is None - if values.dtype == np.int64: - values = convert_int64(values) - refit_dict[name] = values - - print(f"Refitting TensorRT engine with {onnx_refit_path} weights") - refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes - - # Construct mapping from weight names in refit model -> original model - name_map = {} - for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): - refit_node = refit_nodes[n] - assert node.op == refit_node.op - # Constant nodes in ONNX do not have inputs but have a constant output - if node.op == "Constant": - name_map[refit_node.outputs[0].name] = node.outputs[0].name - # Handle scale and bias weights - elif node.op == "Conv": - if node.inputs[1].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" - if node.inputs[2].__class__ == gs.Constant: - name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" - # For all other nodes: find node inputs that are initializers (gs.Constant) - else: - for i, inp in enumerate(node.inputs): - if inp.__class__ == gs.Constant: - name_map[refit_node.inputs[i].name] = inp.name - - def map_name(name): - if name in name_map: - return name_map[name] - return name - - # Construct refit dictionary - refit_dict = {} - refitter = trt.Refitter(self.engine, TRT_LOGGER) - all_weights = refitter.get_all() - for layer_name, role in zip(all_weights[0], all_weights[1]): - # for specialized roles, use a unique name in the map: - if role == trt.WeightsRole.KERNEL: - name = layer_name + "_TRTKERNEL" - elif role == trt.WeightsRole.BIAS: - name = layer_name + "_TRTBIAS" - else: - name = layer_name - - assert name not in refit_dict, "Found duplicate layer: " + name - refit_dict[name] = None - - for n in refit_nodes: - # Constant nodes in ONNX do not have inputs but have a constant output - if n.op == "Constant": - name = map_name(n.outputs[0].name) - print(f"Add Constant {name}\n") - add_to_map(refit_dict, name, n.outputs[0].values) - - # Handle scale and bias weights - elif n.op == "Conv": - if n.inputs[1].__class__ == gs.Constant: - name = map_name(n.name + "_TRTKERNEL") - add_to_map(refit_dict, name, n.inputs[1].values) - - if n.inputs[2].__class__ == gs.Constant: - name = map_name(n.name + "_TRTBIAS") - add_to_map(refit_dict, name, n.inputs[2].values) - - # For all other nodes: find node inputs that are initializers (AKA gs.Constant) - else: - for inp in n.inputs: - name = map_name(inp.name) - if inp.__class__ == gs.Constant: - add_to_map(refit_dict, name, inp.values) - - for layer_name, weights_role in zip(all_weights[0], all_weights[1]): - if weights_role == trt.WeightsRole.KERNEL: - custom_name = layer_name + "_TRTKERNEL" - elif weights_role == trt.WeightsRole.BIAS: - custom_name = layer_name + "_TRTBIAS" - else: - custom_name = layer_name - - # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model - if layer_name.startswith("onnx::Trilu"): - continue - - if refit_dict[custom_name] is not None: - refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) - else: - print(f"[W] No refit weights for layer: {layer_name}") - - if not refitter.refit_cuda_engine(): - print("Failed to refit!") - exit(0) - def build( self, onnx_path, fp16, input_profile=None, - enable_refit=False, - enable_preview=False, enable_all_tactics=False, timing_cache=None, update_output_names=None, @@ -214,7 +107,7 @@ def build( engine = engine_from_network( network, config=CreateConfig( - fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs ), save_timing_cache=timing_cache, ) @@ -294,7 +187,6 @@ def __init__( self, pipeline_info: PipelineInfo, max_batch_size=16, - hf_token=None, device="cuda", use_cuda_graph=False, ): @@ -306,8 +198,6 @@ def __init__( Version and Type of pipeline. max_batch_size (int): Maximum batch size for dynamic batch engine. - hf_token (str): - HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. device (str): device to run. use_cuda_graph (bool): @@ -317,7 +207,6 @@ def __init__( EngineType.TRT, pipeline_info, max_batch_size=max_batch_size, - hf_token=hf_token, device=device, use_cuda_graph=use_cuda_graph, ) @@ -348,16 +237,10 @@ def load_engines( opt_batch_size, opt_image_height, opt_image_width, - force_export=False, - force_optimize=False, - force_build=False, static_batch=False, static_shape=True, - enable_refit=False, - enable_preview=False, enable_all_tactics=False, timing_cache=None, - onnx_refit_dir=None, ): """ Build and load engines for TensorRT accelerated inference. @@ -378,26 +261,14 @@ def load_engines( Image height to optimize for during engine building. Must be a multiple of 8. opt_image_width (int): Image width to optimize for during engine building. Must be a multiple of 8. - force_export (bool): - Force re-exporting the ONNX models. - force_optimize (bool): - Force re-optimizing the ONNX models. - force_build (bool): - Force re-building the TensorRT engine. static_batch (bool): Build engine only for specified opt_batch_size. static_shape (bool): Build engine only for specified opt_image_height & opt_image_width. Default = True. - enable_refit (bool): - Build engines with refit option enabled. - enable_preview (bool): - Enable TensorRT preview features. enable_all_tactics (bool): Enable all tactic sources during TensorRT engine builds. timing_cache (str): Path to the timing cache to accelerate build or None - onnx_refit_dir (str): - Directory containing refit ONNX models. """ # Create directory for directory in [engine_dir, onnx_dir]: @@ -417,11 +288,11 @@ def load_engines( opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape ) engine_path = self.get_engine_path(engine_dir, model_name, profile_id) - if force_export or force_build or not os.path.exists(engine_path): + if not os.path.exists(engine_path): onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) - if force_export or not os.path.exists(onnx_opt_path): - if force_export or not os.path.exists(onnx_path): + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): load_lora = True break @@ -436,11 +307,11 @@ def load_engines( opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape ) engine_path = self.get_engine_path(engine_dir, model_name, profile_id) - if force_export or force_build or not os.path.exists(engine_path): + if not os.path.exists(engine_path): onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) - if force_export or not os.path.exists(onnx_opt_path): - if force_export or not os.path.exists(onnx_path): + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): print(f"Exporting model: {onnx_path}") model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir) @@ -464,7 +335,7 @@ def load_engines( print(f"Found cached model: {onnx_path}") # Optimize onnx - if force_optimize or not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_opt_path): print(f"Generating optimizing model: {onnx_opt_path}") model_obj.optimize_trt(onnx_path, onnx_opt_path) else: @@ -482,7 +353,7 @@ def load_engines( engine = TensorrtEngine(engine_path) onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) - if force_build or not os.path.exists(engine.engine_path): + if not os.path.exists(engine.engine_path): engine.build( onnx_opt_path, fp16=True, @@ -493,8 +364,6 @@ def load_engines( static_batch, static_shape, ), - enable_refit=enable_refit, - enable_preview=enable_preview, enable_all_tactics=enable_all_tactics, timing_cache=timing_cache, update_output_names=None, @@ -506,10 +375,6 @@ def load_engines( if model_name == "vae" and self.vae_torch_fallback: continue self.engines[model_name].load() - if onnx_refit_dir: - onnx_refit_path = self.get_onnx_path(model_name, onnx_refit_dir, opt=True) - if os.path.exists(onnx_refit_path): - self.engines[model_name].refit(onnx_opt_path, onnx_refit_path) def max_device_memory(self): max_device_memory = 0 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py new file mode 100644 index 0000000000000..0c59d5485f1cb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_torch.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +logger = logging.getLogger(__name__) + + +class TorchEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.TORCH, + pipeline_info, + max_batch_size=max_batch_size, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + if use_cuda_graph: + self.compile_config = { + "clip": {"mode": "reduce-overhead", "dynamic": False}, + "clip2": {"mode": "reduce-overhead", "dynamic": False}, + "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False}, + "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False}, + "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False}, + } + + def build_engines( + self, + framework_model_dir: str, + ): + import torch + + self.torch_device = torch.device("cuda", torch.cuda.current_device()) + self.load_models(framework_model_dir) + + pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None + + built_engines = {} + for model_name, model_obj in self.models.items(): + model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir) + if self.pipeline_info.is_xl() and not self.custom_fp16_vae: + model = model.to(device=self.torch_device, dtype=torch.float32) + else: + model = model.to(device=self.torch_device, dtype=torch.float16) + + if model_name in self.compile_config: + compile_config = self.compile_config[model_name] + if model_name in ["unet", "unetxl"]: + model.to(memory_format=torch.channels_last) + engine = torch.compile(model, **compile_config) + built_engines[model_name] = engine + else: # eager mode + built_engines[model_name] = model + + self.engines = built_engines + + def run_engine(self, model_name, feed_dict): + if model_name in ["unet", "unetxl"]: + if "controlnet_images" in feed_dict: + return {"latent": self.engines[model_name](**feed_dict)} + + if model_name == "unetxl": + added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]} + return { + "latent": self.engines[model_name]( + feed_dict["sample"], + feed_dict["timestep"], + feed_dict["encoder_hidden_states"], + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + } + + return { + "latent": self.engines[model_name]( + feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False + )[0] + } + + if model_name in ["vae_encoder"]: + return {"latent": self.engines[model_name](feed_dict["images"])} + + raise RuntimeError(f"Shall not reach here: {model_name}") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py deleted file mode 100644 index 37785869a355b..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ /dev/null @@ -1,292 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Stable diffusion text to image pipeline using ONNX Runtime CUDA execution provider. -Based on https://github.com/huggingface/diffusers/blob/v0.17.1/examples/community/stable_diffusion_tensorrt_txt2img.py -Modifications: (1) Create ONNX Runtime session (2) Use I/O Binding of ONNX Runtime for inference - -Installation instructions -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade transformers diffusers>=0.16.0 -pip install numpy>=1.24.1 onnx>=1.13.0 coloredlogs protobuf==3.20.3 psutil sympy -pip install onnxruntime-gpu -""" - -import logging -import os -from typing import List, Optional, Union - -import torch -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionPipeline, - StableDiffusionPipelineOutput, - StableDiffusionSafetyChecker, -) -from diffusers.schedulers import DDIMScheduler -from diffusion_models import CLIP, VAE, PipelineInfo, UNet -from ort_utils import Engines, StableDiffusionPipelineMixin -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -logger = logging.getLogger(__name__) - - -class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. - This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - # ONNX export parameters - onnx_opset: int = 14, - onnx_dir: str = "onnx_ort", - # Onnxruntime execution provider parameters - engine_dir: str = "ORT_CUDA", - force_engine_rebuild: bool = False, - enable_cuda_graph: bool = False, - pipeline_info: PipelineInfo = None, - ): - super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker - ) - - self.vae.forward = self.vae.decode - self.unet_in_channels = unet.config.in_channels - - self.inpaint = False - self.onnx_dir = onnx_dir - self.engine_dir = engine_dir - self.force_engine_rebuild = force_engine_rebuild - self.enable_cuda_graph = enable_cuda_graph - - self.max_batch_size = 16 - - self.models = {} # loaded in __load_models() - self.engines = Engines("CUDAExecutionProvider", onnx_opset) - - self.fp16 = False - - self.pipeline_info = pipeline_info - - def load_models(self): - assert self.pipeline_info.clip_embedding_dim() == self.text_encoder.config.hidden_size - - stages = self.pipeline_info.stages() - if "clip" in stages: - self.models["clip"] = CLIP( - self.pipeline_info, - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - clip_skip=0, - ) - - if "unet" in stages: - self.models["unet"] = UNet( - self.pipeline_info, - self.unet, - device=self.torch_device, - fp16=False, - max_batch_size=self.max_batch_size, - unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), - ) - - if "vae" in stages: - self.models["vae"] = VAE( - self.pipeline_info, - self.vae, - device=self.torch_device, - max_batch_size=self.max_batch_size, - ) - - def to( - self, - torch_device: Union[str, torch.device], - torch_dtype: Optional[torch.dtype] = None, - silence_dtype_warnings: bool = False, - ): - self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) - self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) - - # set device - self.torch_device = torch.device(torch_device) - - # load models - self.fp16 = torch_dtype == torch.float16 - self.load_models() - - # build engines - self.engines.build( - self.models, - self.engine_dir, - self.onnx_dir, - force_engine_rebuild=self.force_engine_rebuild, - fp16=self.fp16, - device_id=self.torch_device.index or torch.cuda.current_device(), - enable_cuda_graph=self.enable_cuda_graph, - ) - - # Load the remaining modules to GPU. - self.text_encoder = None - self.vae = None - self.unet = None - super().to(torch_device, torch_dtype, silence_dtype_warnings=silence_dtype_warnings) - - self.torch_device = self._execution_device - logger.info(f"Running inference on device: {self.torch_device}") - - return self - - def __allocate_buffers(self, image_height, image_width, batch_size): - # Allocate output tensors for I/O bindings - for model_name, obj in self.models.items(): - self.engines.get_engine(model_name).allocate_buffers( - obj.get_shape_dict(batch_size, image_height, image_width) - ) - - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - image_height: int = 512, - image_width: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - """ - self.generator = generator - self.denoising_steps = num_inference_steps - self.guidance_scale = guidance_scale - - # Pre-compute latent input scales and linear multistep coefficients - self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) - - # Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") - - if negative_prompt is None: - negative_prompt = [""] * batch_size - - if negative_prompt is not None and isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - - assert len(prompt) == len(negative_prompt) - - if batch_size > self.max_batch_size: - raise ValueError( - f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" - ) - - self.__allocate_buffers(image_height, image_width, batch_size) - - with torch.inference_mode(), torch.autocast("cuda"): - # CLIP text encoder - text_embeddings = self.encode_prompt(self.engines.get_engine("clip"), prompt, negative_prompt) - - # Pre-initialize latents - num_channels_latents = self.unet_in_channels - latents = self.prepare_latents( - batch_size, - num_channels_latents, - image_height, - image_width, - torch.float16 if self.fp16 else torch.float32, - self.torch_device, - generator, - ) - - # UNet denoiser - latents = self.denoise_latent( - self.engines.get_engine("unet"), latents, text_embeddings, timestep_fp16=self.fp16 - ) - - # VAE decode latent - images = self.decode_latent(self.engines.get_engine("vae"), latents) - - images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) - images = self.numpy_to_pil(images) - return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -def example(): - pipeline_info = PipelineInfo("1.5") - model_name_or_path = pipeline_info.name() - scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( - model_name_or_path, - scheduler=scheduler, - pipeline_info=pipeline_info, - ) - - # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name_or_path, resume_download=True, local_files_only=True) - - pipe = pipe.to("cuda", torch_dtype=torch.float16) - - prompt = "photorealistic new zealand hills" - image = pipe(prompt).images[0] - image.save("ort_cuda_txt2img_new_zealand_hills.png") - - -if __name__ == "__main__": - example() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py deleted file mode 100644 index c663e37c7ea7d..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ /dev/null @@ -1,261 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Stable diffusion text to image pipeline using ONNX Runtime TensorRT execution provider. -Based on https://github.com/huggingface/diffusers/blob/v0.17.1/examples/community/stable_diffusion_tensorrt_txt2img.py -Modifications: (1) Create ONNX Runtime session (2) Use I/O Binding of ONNX Runtime for inference - -Installation instructions -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade transformers diffusers>=0.16.0 -pip install --upgrade tensorrt>=8.6.1 -pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com -pip install onnxruntime-gpu -""" - -import logging -import os -from typing import List, Optional, Union - -import torch -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionPipeline, - StableDiffusionPipelineOutput, - StableDiffusionSafetyChecker, -) -from diffusers.schedulers import DDIMScheduler -from diffusion_models import PipelineInfo -from engine_builder_ort_trt import OrtTensorrtEngineBuilder -from ort_utils import StableDiffusionPipelineMixin -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -logger = logging.getLogger(__name__) - - -class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. - - This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - image_height: int = 768, - image_width: int = 768, - max_batch_size: int = 16, - # ONNX export parameters - onnx_opset: int = 17, - onnx_dir: str = "onnx_trt", - # TensorRT engine build parameters - engine_dir: str = "ORT_TRT", # use short name here to avoid path exceeds 260 chars in Windows. - force_engine_rebuild: bool = False, - enable_cuda_graph: bool = False, - pipeline_info: Optional[PipelineInfo] = None, - ): - super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker - ) - - self.vae.forward = self.vae.decode - - self.image_height = image_height - self.image_width = image_width - self.onnx_opset = onnx_opset - self.onnx_dir = onnx_dir - self.engine_dir = engine_dir - self.force_engine_rebuild = force_engine_rebuild - - # Although cuda graph requires static input shape, engine built with dynamic batch gets better performance in T4. - # Use static batch could reduce GPU memory footprint. - self.build_static_batch = enable_cuda_graph - - # TODO: support dynamic image shape. - self.build_dynamic_shape = False - - self.max_batch_size = max_batch_size - # Restrict batch size to 4 for larger image dimensions as a walkaround for TensorRT limitation. - if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: - self.max_batch_size = 4 - - self.engines = {} # loaded in build_engines() - self.engine_builder = OrtTensorrtEngineBuilder( - pipeline_info, max_batch_size=max_batch_size, use_cuda_graph=enable_cuda_graph - ) - - self.pipeline_info = pipeline_info - self.stages = pipeline_info.stages() - - def to( - self, - torch_device: Optional[Union[str, torch.device]] = None, - silence_dtype_warnings: bool = False, - ): - super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings) - - self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) - self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) - - # set device - self.torch_device = self._execution_device - logger.info(f"Running inference on device: {self.torch_device}") - - self.engines = self.engine_builder.build_engines( - self.engine_dir, - None, - self.onnx_dir, - self.onnx_opset, - opt_image_height=self.image_height, - opt_image_width=self.image_width, - force_engine_rebuild=self.force_engine_rebuild, - static_batch=self.build_static_batch, - static_image_shape=not self.build_dynamic_shape, - device_id=self.torch_device.index, - ) - - return self - - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - - """ - self.generator = generator - self.denoising_steps = num_inference_steps - self.guidance_scale = guidance_scale - - # Pre-compute latent input scales and linear multistep coefficients - self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) - - # Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") - - if negative_prompt is None: - negative_prompt = [""] * batch_size - - if negative_prompt is not None and isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - - assert len(prompt) == len(negative_prompt) - - if batch_size > self.max_batch_size: - raise ValueError( - f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" - ) - - self.engine_builder.load_resources(self.image_height, self.image_width, batch_size) - - with torch.inference_mode(), torch.autocast("cuda"): - # CLIP text encoder - text_embeddings = self.encode_prompt(self.engines["clip"], prompt, negative_prompt) - - # Pre-initialize latents - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size, - num_channels_latents, - self.image_height, - self.image_width, - torch.float32, - self.torch_device, - generator, - ) - - # UNet denoiser - latents = self.denoise_latent(self.engines["unet"], latents, text_embeddings) - - # VAE decode latent - images = self.decode_latent(self.engines["vae"], latents) - - images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) - images = self.numpy_to_pil(images) - return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -if __name__ == "__main__": - pipeline_info = PipelineInfo("1.5") - model_name_or_path = pipeline_info.name() - scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - - pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( - model_name_or_path, - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - image_height=512, - image_width=512, - max_batch_size=4, - pipeline_info=pipeline_info, - ) - - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name_or_path, revision="fp16") - - pipe = pipe.to("cuda") - - prompt = "photorealistic new zealand hills" - image = pipe(prompt).images[0] - image.save("ort_trt_txt2img_new_zealand_hills.png") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 0afa13a0f4dca..f238b70389371 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -3,16 +3,9 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import gc import logging import os -import shutil import sys -from typing import Union - -import torch - -import onnxruntime as ort logger = logging.getLogger(__name__) @@ -26,237 +19,7 @@ def add_transformers_dir_to_path(): add_transformers_dir_to_path() -from io_binding_helper import CudaSession # noqa: E402. Walk-around to test locally - - -# ----------------------------------------------------------------------------------------------------- -# Utilities for CUDA EP -# ----------------------------------------------------------------------------------------------------- -class Engine(CudaSession): - def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): - self.engine_path = engine_path - self.provider = provider - self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) - - device = torch.device("cuda", device_id) - ort_session = ort.InferenceSession( - self.engine_path, - providers=[ - (provider, self.provider_options), - "CPUExecutionProvider", - ], - ) - - super().__init__(ort_session, device, enable_cuda_graph) - - -class Engines: - def __init__(self, provider, onnx_opset: int = 14): - self.provider = provider - self.engines = {} - self.onnx_opset = onnx_opset - - @staticmethod - def get_onnx_path(onnx_dir, model_name): - return os.path.join(onnx_dir, model_name + ".onnx") - - @staticmethod - def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id + ".onnx") - - def build( - self, - models, - engine_dir: str, - onnx_dir: str, - force_engine_rebuild: bool = False, - fp16: bool = True, - device_id: int = 0, - enable_cuda_graph: bool = False, - ): - profile_id = "_fp16" if fp16 else "_fp32" - - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - onnx_path = Engines.get_onnx_path(onnx_dir, model_name) - onnx_opt_path = Engines.get_engine_path(engine_dir, model_name, profile_id) - if os.path.exists(onnx_opt_path): - logger.info("Found cached optimized model: %s", onnx_opt_path) - else: - if os.path.exists(onnx_path): - logger.info("Found cached model: %s", onnx_path) - else: - logger.info("Exporting model: %s", onnx_path) - model = model_obj.get_model().to(model_obj.device) - with torch.inference_mode(): - inputs = model_obj.get_sample_input(1, 512, 512) - fp32_inputs = tuple( - [ - (tensor.to(torch.float32) if tensor.dtype == torch.float16 else tensor) - for tensor in inputs - ] - ) - - torch.onnx.export( - model, - fp32_inputs, - onnx_path, - export_params=True, - opset_version=self.onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - - # Optimize onnx - logger.info("Generating optimized model: %s", onnx_opt_path) - model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=fp16) - - for model_name in models: - engine_path = Engines.get_engine_path(engine_dir, model_name, profile_id) - engine = Engine(engine_path, self.provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph) - logger.info("%s options for %s: %s", self.provider, model_name, engine.provider_options) - self.engines[model_name] = engine - - def get_engine(self, model_name): - return self.engines[model_name] - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -# ----------------------------------------------------------------------------------------------------- -# Utilities for both CUDA and TensorRT EP -# ----------------------------------------------------------------------------------------------------- - - -class StableDiffusionPipelineMixin: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def encode_prompt(self, clip_engine, prompt, negative_prompt): - """ - Encodes the prompt into text encoder hidden states. - """ - - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(clip_engine, {"input_ids": text_input_ids})["text_embeddings"].clone() - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = run_engine(clip_engine, {"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def denoise_latent( - self, - unet_engine, - latents, - text_embeddings, - timesteps=None, - mask=None, - masked_image_latents=None, - timestep_fp16=False, - ): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # Predict the noise residual - timestep_float = timestep.to(torch.float16) if timestep_fp16 else timestep.to(torch.float32) - - noise_pred = run_engine( - unet_engine, - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def decode_latent(self, vae_engine, latents): - images = run_engine(vae_engine, {"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - - def set_cached_folder(self, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - from diffusers.utils import DIFFUSERS_CACHE - from huggingface_hub import snapshot_download - - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - self.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) +# Walkaround so that we can test local change without building new package +from io_binding_helper import CudaSession # noqa +from onnx_model import OnnxModel # noqa diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py deleted file mode 100644 index 31ede1ba901f2..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py +++ /dev/null @@ -1,236 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# Modified from TensorRT demo diffusion, which has the following license: -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- - -import time - -import torch -from diffusion_models import PipelineInfo -from pipeline_stable_diffusion import StableDiffusionPipeline - - -class Img2ImgXLPipeline(StableDiffusionPipeline): - """ - Stable Diffusion Img2Img XL pipeline. - """ - - def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): - """ - Initializes the Img2Img XL Diffusion pipeline. - - Args: - pipeline_info (PipelineInfo): - Version and Type of stable diffusion pipeline. - """ - assert pipeline_info.is_xl_refiner() - - super().__init__(pipeline_info, *args, **kwargs) - - self.requires_aesthetics_score = True - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype - ): - if self.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0).to(device=self.device) - return add_time_ids - - def _infer( - self, - prompt, - negative_prompt, - init_image, - image_height, - image_width, - denoising_steps=30, - strength=0.3, - guidance=5.0, - seed=None, - warmup=False, - return_type="image", - ): - assert negative_prompt is None or len(prompt) == len(negative_prompt) - - original_size = (image_height, image_width) - crops_coords_top_left = (0, 0) - target_size = (image_height, image_width) - - aesthetic_score = 6.0 - negative_aesthetic_score = 2.5 - - self.set_denoising_steps(denoising_steps) - self.set_random_seed(seed) - - with torch.inference_mode(), torch.autocast("cuda"): - batch_size = len(prompt) - - torch.cuda.synchronize() - e2e_tic = time.perf_counter() - - # Initialize timesteps - timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) - - latent_timestep = timesteps[:1].repeat(batch_size) - - # CLIP text encoder 2 - text_embeddings, pooled_embeddings2 = self.encode_prompt( - prompt, - negative_prompt, - encoder="clip2", - tokenizer=self.tokenizer2, - pooled_outputs=True, - output_hidden_states=True, - ) - - # Time embeddings - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - dtype=text_embeddings.dtype, - ) - - add_time_ids = add_time_ids.repeat(batch_size, 1) - - add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} - - # Pre-process input image - init_image = self.preprocess_images(batch_size, (init_image,))[0] - - # VAE encode init image - if init_image.shape[1] == 4: - init_latents = init_image - else: - init_latents = self.encode_image(init_image) - - # Add noise to latents using timesteps - noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float32, generator=self.generator) - latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) - - # UNet denoiser - latents = self.denoise_latent( - latents, - text_embeddings, - timesteps=timesteps, - step_offset=t_start, - denoiser="unetxl", - guidance=guidance, - add_kwargs=add_kwargs, - ) - - with torch.inference_mode(): - # VAE decode latent - if return_type == "latent": - images = latents - else: - images = self.decode_latent(latents / self.vae_scaling_factor) - - torch.cuda.synchronize() - e2e_toc = time.perf_counter() - - perf_data = None - if not warmup: - print("SD-XL Refiner Pipeline") - perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - - return images, perf_data - - def run( - self, - prompt, - negative_prompt, - init_image, - image_height, - image_width, - denoising_steps=30, - guidance=5.0, - strength=0.3, - seed=None, - warmup=False, - return_type="image", - ): - """ - Run the diffusion pipeline. - - Args: - prompt (str): - The text prompt to guide image generation. - negative_prompt (str): - The prompt not to guide the image generation. - init_image (tuple[torch.Tensor]): - Image from base pipeline. - image_height (int): - Height (in pixels) of the image to be generated. Must be a multiple of 8. - image_width (int): - Width (in pixels) of the image to be generated. Must be a multiple of 8. - denoising_steps (int): - Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. - guidance (float): - Higher guidance scale encourages to generate images that are closely linked to the text prompt. - seed (int): - Seed for the random generator - warmup (bool): - Indicate if this is a warmup run. - return_type (str): - It can be "latent" or "image". - """ - - if self.is_backend_tensorrt(): - import tensorrt as trt - from trt_utilities import TRT_LOGGER - - with trt.Runtime(TRT_LOGGER): - return self._infer( - prompt, - negative_prompt, - init_image, - image_height, - image_width, - denoising_steps=denoising_steps, - strength=strength, - guidance=guidance, - seed=seed, - warmup=warmup, - return_type=return_type, - ) - else: - return self._infer( - prompt, - negative_prompt, - init_image, - image_height, - image_width, - denoising_steps=denoising_steps, - strength=strength, - guidance=guidance, - seed=seed, - warmup=warmup, - return_type=return_type, - ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index e18a68d3edef8..85106f29167d4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -23,7 +23,8 @@ import os import pathlib import random -from typing import Any, Dict, List +import time +from typing import Any, Dict, List, Optional import numpy as np import nvtx @@ -35,6 +36,8 @@ from engine_builder_ort_cuda import OrtCudaEngineBuilder from engine_builder_ort_trt import OrtTensorrtEngineBuilder from engine_builder_tensorrt import TensorrtEngineBuilder +from engine_builder_torch import TorchEngineBuilder +from PIL import Image class StableDiffusionPipeline: @@ -49,12 +52,11 @@ def __init__( scheduler="DDIM", device="cuda", output_dir=".", - hf_token=None, verbose=False, nvtx_profile=False, use_cuda_graph=False, framework_model_dir="pytorch_model", - engine_type: EngineType = EngineType.ORT_TRT, + engine_type: EngineType = EngineType.ORT_CUDA, ): """ Initializes the Diffusion pipeline. @@ -70,8 +72,6 @@ def __init__( PyTorch device to run inference. Default: 'cuda' output_dir (str): Output directory for log files and image artifacts - hf_token (str): - HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. verbose (bool): Enable verbose logging. nvtx_profile (bool): @@ -98,7 +98,6 @@ def __init__( print(f"[I] Create directory: {directory}") pathlib.Path(directory).mkdir(parents=True) - self.hf_token = hf_token self.device = device self.torch_device = torch.device(device, torch.cuda.current_device()) self.verbose = verbose @@ -118,24 +117,22 @@ def __init__( # backend engine self.engine_type = engine_type if engine_type == EngineType.TRT: - self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph) elif engine_type == EngineType.ORT_TRT: - self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph) elif engine_type == EngineType.ORT_CUDA: - self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph) + elif engine_type == EngineType.TORCH: + self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph) else: raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") # Load text tokenizer if not self.pipeline_info.is_xl_refiner(): - self.tokenizer = get_tokenizer( - self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" - ) + self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer") if self.pipeline_info.is_xl(): - self.tokenizer2 = get_tokenizer( - self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" - ) + self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2") self.control_image_processor = None if self.pipeline_info.is_xl() and self.pipeline_info.controlnet: @@ -147,7 +144,7 @@ def __init__( # Create CUDA events self.events = {} - for stage in ["clip", "denoise", "vae", "vae_encoder"]: + for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]: for marker in ["start", "stop"]: self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] self.markers = {} @@ -211,7 +208,7 @@ def run_engine(self, model_name, feed_dict): return self.backend.run_engine(model_name, feed_dict) def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width): - latents_dtype = torch.float32 # text_embeddings.dtype + latents_dtype = torch.float16 latents_shape = (batch_size, unet_channels, latent_height, latent_width) latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) # Scale the initial noise by the standard deviation required by the scheduler @@ -219,6 +216,7 @@ def initialize_latents(self, batch_size, unet_channels, latent_height, latent_wi return latents def initialize_timesteps(self, timesteps, strength): + """Initialize timesteps for refiner.""" self.scheduler.set_timesteps(timesteps) offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 init_timestep = int(timesteps * strength) + offset @@ -227,6 +225,51 @@ def initialize_timesteps(self, timesteps, strength): timesteps = self.scheduler.timesteps[t_start:].to(self.device) return timesteps, t_start + def initialize_refiner(self, batch_size, image, strength): + """Add noise to a reference image.""" + # Initialize timesteps + timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + + latent_timestep = timesteps[:1].repeat(batch_size) + + # Pre-process input image + image = self.preprocess_images(batch_size, (image,))[0] + + # VAE encode init image + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.encode_image(image) + + # Add noise to latents using timesteps + noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator) + + latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) + + return timesteps, t_start, latents + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + requires_aesthetics_score, + ): + if requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + def start_profile(self, name, color="blue"): if self.nvtx_profile: self.markers[name] = nvtx.start_range(message=name, color=color) @@ -245,7 +288,7 @@ def preprocess_images(self, batch_size, images=()): self.start_profile("preprocess", color="pink") init_images = [] for i in images: - image = i.to(self.device).float() + image = i.to(self.device) if image.shape[0] != batch_size: image = image.repeat(batch_size, 1, 1, 1) init_images.append(image) @@ -296,30 +339,46 @@ def encode_prompt( output_hidden_states=False, force_zeros_for_empty_prompt=False, do_classifier_free_guidance=True, + dtype=torch.float16, ): if tokenizer is None: tokenizer = self.tokenizer self.start_profile("clip", color="green") - # Tokenize prompt - text_input_ids = ( - tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + def tokenize(prompt, output_hidden_states): + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) ) - .input_ids.type(torch.int32) - .to(self.device) - ) + + hidden_states = None + if self.engine_type == EngineType.TORCH: + outputs = self.backend.engines[encoder](text_input_ids) + text_embeddings = outputs[0] + if output_hidden_states: + hidden_states = outputs["last_hidden_state"] + else: + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"] + if output_hidden_states: + hidden_states = outputs["hidden_states"] + return text_embeddings, hidden_states + + # Tokenize prompt + text_embeddings, hidden_states = tokenize(prompt, output_hidden_states) # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) - text_embeddings = outputs["text_embeddings"].clone() - if output_hidden_states: - hidden_states = outputs["hidden_states"].clone() + text_embeddings = text_embeddings.clone() + if hidden_states is not None: + hidden_states = hidden_states.clone() # Note: negative prompt embedding is not needed for SD XL when guidance <= 1 if do_classifier_free_guidance: @@ -331,22 +390,7 @@ def encode_prompt( uncond_hidden_states = torch.zeros_like(hidden_states) else: # Tokenize negative prompt - uncond_input_ids = ( - tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.device) - ) - - outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) - uncond_embeddings = outputs["text_embeddings"] - if output_hidden_states: - uncond_hidden_states = outputs["hidden_states"] + uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states) # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) @@ -363,8 +407,8 @@ def encode_prompt( self.stop_profile("clip") if pooled_outputs: - return text_embeddings.to(dtype=torch.float16), pooled_output.to(dtype=torch.float16) - return text_embeddings.to(dtype=torch.float16) + return text_embeddings.to(dtype=dtype), pooled_output.to(dtype=dtype) + return text_embeddings.to(dtype=dtype) def denoise_latent( self, @@ -373,8 +417,6 @@ def denoise_latent( denoiser="unet", timesteps=None, step_offset=0, - mask=None, - masked_image_latents=None, guidance=7.5, add_kwargs=None, ): @@ -393,18 +435,13 @@ def denoise_latent( latent_model_input, step_offset + step_index, timestep ) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - # Predict the noise residual if self.nvtx_profile: nvtx_unet = nvtx.start_range(message="unet", color="blue") - timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - params = { "sample": latent_model_input, - "timestep": timestep_float, + "timestep": timestep.to(latents.dtype), "encoder_hidden_states": text_embeddings, } @@ -434,9 +471,9 @@ def denoise_latent( self.stop_profile("denoise") return latents - def encode_image(self, init_image): + def encode_image(self, image): self.start_profile("vae_encoder", color="red") - init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] + init_latents = self.run_engine("vae_encoder", {"images": image})["latent"] init_latents = self.vae_scaling_factor * init_latents self.stop_profile("vae_encoder") return init_latents @@ -447,7 +484,7 @@ def decode_latent(self, latents): self.stop_profile("vae") return images - def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: + def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> Dict[str, Any]: throughput = batch_size / (toc - tic) latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1] @@ -457,6 +494,8 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: if vae_enc else None ) + latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None + latency = (toc - tic) * 1000.0 print("|----------------|--------------|") @@ -472,9 +511,11 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: ) ) print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) - + pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline" + if pil: + print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil)) print("|----------------|--------------|") - print("| {:^14} | {:>9.2f} ms |".format("Pipeline", latency)) + print(f"| {pipeline:^14} | {latency:>9.2f} ms |") print("|----------------|--------------|") print(f"Throughput: {throughput:.2f} image/s") @@ -482,6 +523,7 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: "latency_clip": latency_clip, "latency_unet": latency_unet, "latency_vae": latency_vae, + "latency_pil": latency_pil, "latency": latency, "throughput": throughput, } @@ -490,15 +532,19 @@ def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: return perf_data @staticmethod - def to_pil_image(images): + def pt_to_pil(images): images = ( ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() ) - - from PIL import Image - return [Image.fromarray(images[i]) for i in range(images.shape[0])] + @staticmethod + def pt_to_numpy(images: torch.FloatTensor): + """ + Convert a PyTorch tensor to a NumPy image. + """ + return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy() + def metadata(self) -> Dict[str, Any]: return { "actual_steps": self.actual_steps, @@ -509,7 +555,6 @@ def metadata(self) -> Dict[str, Any]: } def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]): - images = self.to_pil_image(images) session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): seed = str(self.get_current_seed()) @@ -527,3 +572,249 @@ def save_images(self, images: List, prompt: List[str], negative_prompt: List[str info.add_text("negative_prompt", negative_prompt[i]) image.save(image_path, "PNG", pnginfo=info) + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + image=None, + strength=0.3, + controlnet_images=None, + controlnet_scales=None, + show_latency=False, + output_type="pil", + ): + if show_latency: + torch.cuda.synchronize() + start_time = time.perf_counter() + + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + timesteps = None + step_offset = 0 + with torch.inference_mode(), torch.autocast("cuda"): + if image is not None: + timesteps, step_offset, latents = self.initialize_refiner( + batch_size=batch_size, + image=image, + strength=strength, + ) + else: + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + do_classifier_free_guidance = guidance > 1.0 + if not self.pipeline_info.is_xl(): + denoiser = "unet" + text_embeddings = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + dtype=latents.dtype, + ) + add_kwargs = {} + else: + denoiser = "unetxl" + + # Time embeddings + original_size = (image_height, image_width) + crops_coords_top_left = (0, 0) + target_size = (image_height, image_width) + aesthetic_score = 6.0 + negative_aesthetic_score = 2.5 + add_time_ids, add_negative_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=latents.dtype, + requires_aesthetics_score=self.pipeline_info.is_xl_refiner(), + ) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1) + + if self.pipeline_info.is_xl_refiner(): + # CLIP text encoder 2 + text_embeddings, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + dtype=latents.dtype, + ) + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + else: # XL Base + # CLIP text encoder + text_embeddings = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip", + tokenizer=self.tokenizer, + output_hidden_states=True, + force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, + dtype=latents.dtype, + ) + # CLIP text encoder 2 + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, + dtype=latents.dtype, + ) + + # Merged text embeddings + text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + if self.pipeline_info.controlnet: + controlnet_images = self.preprocess_controlnet_images( + latents.shape[0], + controlnet_images, + do_classifier_free_guidance=do_classifier_free_guidance, + height=image_height, + width=image_width, + ) + add_kwargs.update( + { + "controlnet_images": controlnet_images, + "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device), + } + ) + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=step_offset, + denoiser=denoiser, + guidance=guidance, + add_kwargs=add_kwargs, + ) + + with torch.inference_mode(): + # VAE decode latent + if output_type == "latent": + images = latents + else: + images = self.decode_latent(latents / self.vae_scaling_factor) + if output_type == "pil": + self.start_profile("pil", color="green") + images = self.pt_to_pil(images) + self.stop_profile("pil") + + perf_data = None + if show_latency: + torch.cuda.synchronize() + end_time = time.perf_counter() + perf_data = self.print_summary( + start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil") + ) + + return images, perf_data + + def run( + self, + prompt: List[str], + negative_prompt: List[str], + image_height: int, + image_width: int, + denoising_steps: int = 30, + guidance: float = 5.0, + seed: Optional[int] = None, + image: Optional[torch.Tensor] = None, + strength: float = 0.3, + controlnet_images: Optional[torch.Tensor] = None, + controlnet_scales: Optional[torch.Tensor] = None, + show_latency: bool = False, + output_type: str = "pil", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (List[str]): + The text prompt to guide image generation. + negative_prompt (List[str]): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + image (tuple[torch.Tensor]): + Reference image. + strength (float): + Indicates extent to transform the reference image, which is used as a starting point, + and more noise is added the higher the strength. + show_latency (bool): + Whether return latency data. + output_type (str): + It can be "latent", "pt" or "pil". + """ + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + image=image, + strength=strength, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, + show_latency=show_latency, + output_type=output_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + image=image, + strength=strength, + controlnet_images=controlnet_images, + controlnet_scales=controlnet_scales, + show_latency=show_latency, + output_type=output_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py deleted file mode 100644 index 2d2fdb542c845..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ /dev/null @@ -1,178 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# Modified from TensorRT demo diffusion, which has the following license: -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- - -import time - -import torch -from diffusion_models import PipelineInfo -from pipeline_stable_diffusion import StableDiffusionPipeline - - -class Txt2ImgPipeline(StableDiffusionPipeline): - """ - Stable Diffusion Txt2Img pipeline using NVidia TensorRT. - """ - - def __init__(self, pipeline_info: PipelineInfo, **kwargs): - """ - Initializes the Txt2Img Diffusion pipeline. - - Args: - pipeline_info (PipelineInfo): - Version and Type of stable diffusion pipeline. - """ - super().__init__(pipeline_info, **kwargs) - - def _infer( - self, - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=50, - guidance=7.5, - seed=None, - controlnet_images=None, - controlnet_scales=None, - warmup=False, - return_type="latent", - ): - assert len(prompt) == len(negative_prompt) - batch_size = len(prompt) - - self.set_denoising_steps(denoising_steps) - self.set_random_seed(seed) - - with torch.inference_mode(), torch.autocast("cuda"): - # Pre-initialize latents - latents = self.initialize_latents( - batch_size=batch_size, - unet_channels=4, - latent_height=(image_height // 8), - latent_width=(image_width // 8), - ) - - torch.cuda.synchronize() - e2e_tic = time.perf_counter() - - # CLIP text encoder - do_classifier_free_guidance = guidance > 1.0 - text_embeddings = self.encode_prompt( - prompt, - negative_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - add_kwargs = None - if self.pipeline_info.controlnet: - controlnet_images = self.preprocess_controlnet_images( - latents.shape[0], controlnet_images, do_classifier_free_guidance=do_classifier_free_guidance - ) - add_kwargs = { - "controlnet_images": controlnet_images, - "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device), - } - - # UNet denoiser - latents = self.denoise_latent(latents, text_embeddings, guidance=guidance, add_kwargs=add_kwargs) - - # VAE decode latent - images = self.decode_latent(latents / self.vae_scaling_factor) - - torch.cuda.synchronize() - e2e_toc = time.perf_counter() - - perf_data = None - if not warmup: - perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - - return images, perf_data - - def run( - self, - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=30, - guidance=7.5, - seed=None, - controlnet_images=None, - controlnet_scales=None, - warmup=False, - return_type="image", - ): - """ - Run the diffusion pipeline. - - Args: - prompt (str): - The text prompt to guide image generation. - negative_prompt (str): - The prompt not to guide the image generation. - image_height (int): - Height (in pixels) of the image to be generated. Must be a multiple of 8. - image_width (int): - Width (in pixels) of the image to be generated. Must be a multiple of 8. - denoising_steps (int): - Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. - guidance (float): - Higher guidance scale encourages to generate images that are closely linked to the text prompt. - seed (int): - Seed for the random generator - warmup (bool): - Indicate if this is a warmup run. - return_type (str): - type of return. The value can be "latent" or "image". - """ - if self.is_backend_tensorrt(): - import tensorrt as trt - from trt_utilities import TRT_LOGGER - - with trt.Runtime(TRT_LOGGER): - return self._infer( - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=denoising_steps, - guidance=guidance, - seed=seed, - controlnet_images=controlnet_images, - controlnet_scales=controlnet_scales, - warmup=warmup, - return_type=return_type, - ) - else: - return self._infer( - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=denoising_steps, - guidance=guidance, - seed=seed, - controlnet_images=controlnet_images, - controlnet_scales=controlnet_scales, - warmup=warmup, - return_type=return_type, - ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py deleted file mode 100644 index fa0035494217b..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ /dev/null @@ -1,231 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# Modified from TensorRT demo diffusion, which has the following license: -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- - -import time - -import torch -from diffusion_models import PipelineInfo -from pipeline_stable_diffusion import StableDiffusionPipeline - - -class Txt2ImgXLPipeline(StableDiffusionPipeline): - """ - Stable Diffusion Txt2Img XL pipeline. - """ - - def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): - """ - Initializes the Txt2Img XL Diffusion pipeline. - - Args: - pipeline_info (PipelineInfo): - Version and Type of stable diffusion pipeline. - """ - assert pipeline_info.is_xl_base_or_turbo() - - super().__init__(pipeline_info, *args, **kwargs) - - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - def _infer( - self, - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=30, - guidance=5.0, - seed=None, - controlnet_images=None, - controlnet_scales=None, - warmup=False, - return_type="image", - ): - assert len(prompt) == len(negative_prompt) - do_classifier_free_guidance = guidance > 1.0 - original_size = (image_height, image_width) - crops_coords_top_left = (0, 0) - target_size = (image_height, image_width) - batch_size = len(prompt) - - self.set_denoising_steps(denoising_steps) - self.set_random_seed(seed) - - with torch.inference_mode(), torch.autocast("cuda"): - # Pre-initialize latents - latents = self.initialize_latents( - batch_size=batch_size, - unet_channels=4, - latent_height=(image_height // 8), - latent_width=(image_width // 8), - ) - - torch.cuda.synchronize() - e2e_tic = time.perf_counter() - - # CLIP text encoder - text_embeddings = self.encode_prompt( - prompt, - negative_prompt, - encoder="clip", - tokenizer=self.tokenizer, - output_hidden_states=True, - force_zeros_for_empty_prompt=True, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - # CLIP text encoder 2 - text_embeddings2, pooled_embeddings2 = self.encode_prompt( - prompt, - negative_prompt, - encoder="clip2", - tokenizer=self.tokenizer2, - pooled_outputs=True, - output_hidden_states=True, - force_zeros_for_empty_prompt=True, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - # Merged text embeddings - text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) - - # Time embeddings - add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype - ) - add_time_ids = add_time_ids.repeat(batch_size, 1) - if do_classifier_free_guidance: - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - - add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids.to(self.device)} - if self.pipeline_info.controlnet: - controlnet_images = self.preprocess_controlnet_images( - latents.shape[0], - controlnet_images, - do_classifier_free_guidance=do_classifier_free_guidance, - height=image_height, - width=image_width, - ) - add_kwargs.update( - { - "controlnet_images": controlnet_images, - "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device), - } - ) - - # UNet denoiser - latents = self.denoise_latent( - latents, - text_embeddings, - denoiser="unetxl", - guidance=guidance, - add_kwargs=add_kwargs, - ) - - # VAE decode latent - if return_type == "latent": - images = latents - else: - images = self.decode_latent(latents / self.vae_scaling_factor) - - torch.cuda.synchronize() - e2e_toc = time.perf_counter() - - perf_data = None - if not warmup: - print("SD-XL Base Pipeline") - perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - - return images, perf_data - - def run( - self, - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=30, - guidance=5.0, - seed=None, - controlnet_images=None, - controlnet_scales=None, - warmup=False, - return_type="image", - ): - """ - Run the diffusion pipeline. - - Args: - prompt (str): - The text prompt to guide image generation. - negative_prompt (str): - The prompt not to guide the image generation. - image_height (int): - Height (in pixels) of the image to be generated. Must be a multiple of 8. - image_width (int): - Width (in pixels) of the image to be generated. Must be a multiple of 8. - denoising_steps (int): - Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. - guidance (float): - Higher guidance scale encourages to generate images that are closely linked to the text prompt. - seed (int): - Seed for the random generator - warmup (bool): - Indicate if this is a warmup run. - return_type (str): - It can be "latent" or "image". - """ - - if self.is_backend_tensorrt(): - import tensorrt as trt - from trt_utilities import TRT_LOGGER - - with trt.Runtime(TRT_LOGGER): - return self._infer( - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=denoising_steps, - guidance=guidance, - seed=seed, - controlnet_images=controlnet_images, - controlnet_scales=controlnet_scales, - warmup=warmup, - return_type=return_type, - ) - else: - return self._infer( - prompt, - negative_prompt, - image_height, - image_width, - denoising_steps=denoising_steps, - guidance=guidance, - seed=seed, - controlnet_images=controlnet_images, - controlnet_scales=controlnet_scales, - warmup=warmup, - return_type=return_type, - ) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 7bdbc08cf733a..37b39c91b5c15 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1311,3 +1311,119 @@ def use_float16(self): queue = sub_graphs return False + + def change_graph_input_type( + self, + graph_input: ValueInfoProto, + new_type: int, + ): + """Change graph input type, and add Cast node if needed. + + Args: + graph_input (ValueInfoProto): input of the graph + new_type (int): new data type like TensorProto.INT32. + + Returns: + NodeProto: a new Cast node that added. None if Cast node is not added. + List[NodeProto]: Cast nodes that have been removed. + """ + assert isinstance(graph_input, ValueInfoProto) + assert self.find_graph_input(graph_input.name) + + if graph_input.type.tensor_type.elem_type == int(new_type): + return None, [] + + graph = self.graph() + new_cast_node = None + nodes_to_remove = [] + + input_name_to_nodes = self.input_name_to_nodes() + if graph_input.name in input_name_to_nodes: + nodes = input_name_to_nodes[graph_input.name] + + # For children that is not Cast node, insert a Cast node to convert int32 to original data type. + nodes_not_cast = [node for node in nodes if node.op_type != "Cast"] + if nodes_not_cast: + node_name = self.create_node_name("Cast") + output_name = node_name + "_" + graph_input.name + new_value_info = graph.value_info.add() + new_value_info.CopyFrom(graph_input) + new_value_info.name = output_name + new_cast_node = helper.make_node( + "Cast", + [graph_input.name], + [output_name], + to=int(graph_input.type.tensor_type.elem_type), + name=node_name, + ) + graph.node.extend([new_cast_node]) + + for node in nodes_not_cast: + OnnxModel.replace_node_input(node, graph_input.name, output_name) + + # For children that is Cast node, no need to insert Cast. + # When the children is Cast to int32, we can remove that Cast node since input type is int32 now. + nodes_cast = [node for node in nodes if node.op_type == "Cast"] + for node in nodes_cast: + if OnnxModel.get_node_attribute(node, "to") == int(new_type): + self.replace_input_of_all_nodes(node.output[0], graph_input.name) + if not self.find_graph_output(node.output[0]): + nodes_to_remove.append(node) + if nodes_to_remove: + self.remove_nodes(nodes_to_remove) + + graph_input.type.tensor_type.elem_type = int(new_type) + return new_cast_node, nodes_to_remove + + def change_graph_output_type( + self, + graph_output: ValueInfoProto, + new_type: int, + ): + """Change graph input type, and add Cast node if needed. + + Args: + graph_input (str | ValueInfoProto): output of the graph + new_type (int): new data type. + + Returns: + NodeProto: a new Cast node that added. None if Cast node is not added. + """ + assert isinstance(graph_output, ValueInfoProto) + assert self.find_graph_output(graph_output.name) + + if graph_output.type.tensor_type.elem_type == int(new_type): + return None + + cast_node = None + graph = self.graph() + + # Add a cast node + node_name = self.create_node_name("Cast") + input_name = node_name + "_" + graph_output.name + self.replace_input_of_all_nodes(graph_output.name, input_name) + new_value_info = graph.value_info.add() + new_value_info.CopyFrom(graph_output) + new_value_info.name = input_name + cast_node = helper.make_node( + "Cast", + [input_name], + [graph_output.name], + to=int(new_type), + name=node_name, + ) + graph.node.extend([cast_node]) + graph_output.type.tensor_type.elem_type = int(new_type) + return cast_node + + def rename_graph_output(self, old_name: str, new_name: str): + if new_name in self.output_name_to_node(): + raise RuntimeError("{new_name} exists in graph") + + graph = self.graph() + for output in graph.output: + if output.name == old_name: + logger.debug("replace output name from %s to %s", old_name, new_name) + self.replace_input_of_all_nodes(old_name, new_name) + self.replace_output_of_all_nodes(old_name, new_name) + output.name = new_name diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 882100a0d019e..51deb67ce5bf3 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -27,7 +27,7 @@ from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization from fusion_utils import FusionUtils -from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper +from onnx import ModelProto, TensorProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -170,78 +170,13 @@ def get_graph_inputs_from_fused_nodes(self, casted: bool): inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted) return inputs - def change_graph_input_type( - self, - graph: GraphProto, - graph_input: ValueInfoProto, - new_type: int = TensorProto.INT32, - ): - """Change graph input type, and add Cast node if needed. - - Args: - graph (GraphProto): graph - graph_input (TensorProto): input of the graph - new_type (int, optional): new data type. Defaults to TensorProto.INT32. - - Returns: - NodeProto: a new Cast node that added. None if Cast node is not added. - List[NodeProto]: Cast nodes that have been removed. - """ - assert isinstance(graph, GraphProto) - assert isinstance(graph_input, ValueInfoProto) - assert self.find_graph_input(graph_input.name) - - if graph_input.type.tensor_type.elem_type == int(new_type): - return None, [] - - new_cast_node = None - nodes_to_remove = [] - - input_name_to_nodes = self.input_name_to_nodes() - if graph_input.name in input_name_to_nodes: - nodes = input_name_to_nodes[graph_input.name] - - # For children that is not Cast node, insert a Cast node to convert int32 to original data type. - nodes_not_cast = [node for node in nodes if node.op_type != "Cast"] - if nodes_not_cast: - node_name = self.create_node_name("Cast") - output_name = node_name + "_" + graph_input.name - new_value_info = graph.value_info.add() - new_value_info.CopyFrom(graph_input) - new_value_info.name = output_name - new_cast_node = helper.make_node( - "Cast", - [graph_input.name], - [output_name], - to=int(graph_input.type.tensor_type.elem_type), - name=node_name, - ) - graph.node.extend([new_cast_node]) - - for node in nodes_not_cast: - OnnxModel.replace_node_input(node, graph_input.name, output_name) - - # For children that is Cast node, no need to insert Cast. - # When the children is Cast to int32, we can remove that Cast node since input type is int32 now. - nodes_cast = [node for node in nodes if node.op_type == "Cast"] - for node in nodes_cast: - if OnnxModel.get_node_attribute(node, "to") == int(new_type): - self.replace_input_of_all_nodes(node.output[0], graph_input.name) - if not self.find_graph_output(node.output[0]): - nodes_to_remove.append(node) - if nodes_to_remove: - self.remove_nodes(nodes_to_remove) - - graph_input.type.tensor_type.elem_type = int(new_type) - return new_cast_node, nodes_to_remove - def change_graph_inputs_to_int32(self): """Change data type of all graph inputs to int32 type, and add Cast node if needed.""" graph = self.graph() add_cast_count = 0 remove_cast_count = 0 for graph_input in graph.input: - new_node, removed_nodes = self.change_graph_input_type(graph, graph_input, TensorProto.INT32) + new_node, removed_nodes = self.change_graph_input_type(graph_input, TensorProto.INT32) if new_node: add_cast_count += 1 remove_cast_count += len(removed_nodes)