diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 67d3c95922a87..4f898245d01bd 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -542,7 +542,7 @@ def measure_gpu_usage(self): while True: for i in range(device_count): max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.005) # 2ms + time.sleep(0.005) # 5ms if not self.keep_measuring: break return [ @@ -555,7 +555,7 @@ def measure_gpu_usage(self): ] -def measure_memory(is_gpu, func, monitor_type="cuda"): +def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): memory_monitor_type = None if monitor_type == "rocm": memory_monitor_type = RocmMemoryMonitor @@ -565,10 +565,16 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): monitor = memory_monitor_type(False) if is_gpu: - memory_before_test = monitor.measure_gpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_gpu_usage() if memory_before_test is None: return None + if func is None: + return memory_before_test + with ThreadPoolExecutor() as executor: monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) @@ -595,7 +601,13 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): return None # CPU memory - memory_before_test = monitor.measure_cpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_cpu_usage() + + if func is None: + return memory_before_test with ThreadPoolExecutor() as executor: monitor = MemoryMonitor() diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 71c1a21d8f768..de17f195c99cc 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]): if name in self.input_names: if self.enable_cuda_graph: assert self.input_tensors[name].nelement() == tensor.nelement() + assert self.input_tensors[name].dtype == tensor.dtype assert tensor.device.type == "cuda" # Please install cuda-python package with a version corresponding to CUDA in your machine. from cuda import cudart diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 7ffefdd05f215..1fbd5092a719a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -74,15 +74,16 @@ Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, p ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with [CUDA 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive) or 11.8. +It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8. +If you use CUDA 12.*, you will need build onnxruntime-gpu from source. ``` conda create -n py38 python=3.8 conda activate py38 -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda.txt ``` - -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests. +ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended. #### Install Nightly (Optional) @@ -233,18 +234,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. -### Run Benchmark with TensorRT and TensorRT execution provider +### Run Benchmark with TensorRT or TensorRT execution provider For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. ``` pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-tensorrt.txt export CUDA_MODULE_LOADING=LAZY python benchmark.py -e tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph + +python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph +python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` ### Example Benchmark output diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 13126f648d290..f8fda13a35b93 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -10,15 +10,18 @@ import sys import time +import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch +from benchmark_helper import measure_memory SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", + "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", } PROVIDERS = { @@ -43,139 +46,13 @@ def example_prompts(): "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", ] - return prompts + negative_prompt = "bad composition, ugly, abnormal, malformed" - -class CudaMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - - def measure_gpu_usage(self): - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) - - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - print(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - print(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - print("Error fetching GPU information using nvml: %s", error) - return None - - -class RocmMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - rocm_smi_path = "/opt/rocm/libexec/rocm_smi" - if os.path.exists(rocm_smi_path): - if rocm_smi_path not in sys.path: - sys.path.append(rocm_smi_path) - try: - import rocm_smi - - self.rocm_smi = rocm_smi - self.rocm_smi.initializeRsmi() - except ImportError: - self.rocm_smi = None - - def get_used_memory(self, dev): - if self.rocm_smi is None: - return -1 - return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 - - def measure_gpu_usage(self): - device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [f"GPU{i}" for i in range(device_count)] - while True: - for i in range(device_count): - max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] + return prompts, negative_prompt def measure_gpu_memory(monitor_type, func, start_memory=None): - if monitor_type is None: - return None - - monitor = monitor_type(False) - memory_before_test = monitor.measure_gpu_usage() - - if start_memory is None: - start_memory = memory_before_test - if start_memory is None: - return None - if func is None: - return start_memory - - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - monitor = monitor_type() - mem_thread = executor.submit(monitor.measure_gpu_usage) - try: - fn_thread = executor.submit(func) - _ = fn_thread.result() - finally: - monitor.keep_measuring = False - max_usage = mem_thread.result() - - if max_usage is None: - return None - - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") - if len(start_memory) >= 1 and len(max_usage) >= 1 and len(start_memory) == len(max_usage): - # When there are multiple GPUs, we will check the one with maximum usage. - max_used = 0 - for i, memory_before in enumerate(start_memory): - before = memory_before["max_used_MB"] - after = max_usage[i]["max_used_MB"] - used = after - before - max_used = max(max_used, used) - return max_used - return None + return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool): @@ -256,7 +133,7 @@ def run_ort_pipeline( assert isinstance(pipe, OnnxStableDiffusionPipeline) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) @@ -275,13 +152,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt, + [prompt] * batch_size, height, width, num_inference_steps=steps, - negative_prompt=None, + negative_prompt=[negative_prompt] * batch_size, guidance_scale=7.5, - num_images_per_prompt=batch_size, ).images inference_end = time.time() latency = inference_end - inference_start @@ -320,7 +196,7 @@ def run_torch_pipeline( start_memory, memory_monitor_type, ): - prompts = example_prompts() + prompts, negative_prompt = example_prompts() # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): @@ -342,13 +218,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt=prompt, + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, guidance_scale=7.5, - negative_prompt=None, - num_images_per_prompt=batch_size, + negative_prompt=[negative_prompt] * batch_size, generator=None, # torch.Generator ).images @@ -427,7 +302,7 @@ def run_ort( def export_and_run_ort( - model_name: str, + version: str, provider: str, batch_size: int, disable_safety_checker: bool, @@ -443,15 +318,19 @@ def export_and_run_ort( assert provider == "CUDAExecutionProvider" from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name, scheduler=scheduler, requires_safety_checker=not disable_safety_checker, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models @@ -473,7 +352,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -481,6 +360,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -514,7 +394,7 @@ def warmup(): def run_ort_trt( - model_name: str, + version: str, batch_size: int, disable_safety_checker: bool, height: int, @@ -528,8 +408,12 @@ def run_ort_trt( enable_cuda_graph: bool, ): from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + assert batch_size <= max_batch_size scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") @@ -544,6 +428,7 @@ def run_ort_trt( max_batch_size=max_batch_size, onnx_opset=17, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines @@ -552,7 +437,7 @@ def run_ort_trt( pipe = pipe.to("cuda") def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipe(["warm up"] * batch_size, negative_prompt=["negative"] * batch_size, num_inference_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -564,7 +449,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -572,6 +457,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -589,7 +475,7 @@ def warmup(): "model_name": model_name, "engine": "onnxruntime", "version": ort_version, - "provider": f"tensorrt{trt_version})", + "provider": f"tensorrt({trt_version})", "directory": pipe.engine_dir, "height": height, "width": width, @@ -606,7 +492,148 @@ def warmup(): } -def run_tensorrt( +def run_ort_trt_static( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, +): + print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + short_name = pipeline_info.short_name() + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.ORT_TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + # Load TensorRT engines and pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) + + def warmup(): + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + pipeline.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": pipeline_info.name(), + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt({trt_version})", + "directory": engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static( + work_dir: str, + version: str, model_name: str, batch_size: int, disable_safety_checker: bool, @@ -618,32 +645,79 @@ def run_tensorrt( start_memory, memory_monitor_type, max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, ): - from diffusers import DDIMScheduler - from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + from cuda import cudart + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() assert batch_size <= max_batch_size - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained( - model_name, - custom_pipeline="stable_diffusion_tensorrt_txt2img", - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - image_height=height, - image_width=width, + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, + use_cuda_graph=True, + engine_type=engine_type, ) - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name, revision="fp16") + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) - pipe = pipe.to("cuda") + # activate engines + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -655,28 +729,225 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break for j in range(batch_count): inference_start = time.time() - images = pipe( + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( [prompt] * batch_size, - num_inference_steps=steps, - ).images + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") - from tensorrt import __version__ as trt_version + pipeline.teardown() + + import tensorrt as trt + + return { + "engine": "tensorrt", + "version": trt.__version__, + "provider": "default", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + import tensorrt as trt + from cuda import cudart + from trt_utilities import init_trt_plugins + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + # Register TensorRT plugins + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) + return pipeline + + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + demo_base.backend.activate_engines(shared_device_memory) + demo_refiner.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + + demo_base.teardown() + demo_refiner.teardown() return { + "model_name": model_name, "engine": "tensorrt", - "version": trt_version, + "version": trt.__version__, "provider": "default", "height": height, "width": width, @@ -688,7 +959,178 @@ def warmup(): "median_latency": statistics.median(latency_list), "first_run_memory_MB": first_run_memory, "second_run_memory_MB": second_run_memory, - "enable_cuda_graph": False, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_ort_trt_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + from cuda import cudart + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + assert batch_size <= max_batch_size + + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.ORT_TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), # TODO: might not work with CUDA_VISIBLE_DEVICES + ) + return pipeline + + from diffusion_models import PipelineInfo + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" + image.save(filename) + print("Image saved to", filename) + + demo_base.teardown() + demo_refiner.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt{trt_version})", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, } @@ -808,6 +1250,15 @@ def parse_arguments(): help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.", ) + parser.add_argument( + "-w", + "--work_dir", + required=False, + type=str, + default=".", + help="Root directory to save exported onnx models, built engines etc.", + ) + parser.add_argument( "--enable_safety_checker", required=False, @@ -922,28 +1373,31 @@ def main(): args = parse_arguments() print(args) - if args.enable_cuda_graph: - if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): - raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + if args.engine == "onnxruntime": + if args.version in ["2.1"]: + # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model. + # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator. + os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" from packaging import version from onnxruntime import __version__ as ort_version - if version.parse(ort_version) < version.parse("1.16"): - raise ValueError( - "CUDA graph requires ONNX Runtime 1.16. You can install nightly like the following:\n" - " pip uninstall onnxruntime-gpu\n" - " pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/" - ) + if version.parse(ort_version) == version.parse("1.16.0"): + # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model. + # The walkaround is to enable fused causal attention, or disable Attention fusion for clip model. + os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1" + + if args.enable_cuda_graph: + if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): + raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + + if version.parse(ort_version) < version.parse("1.16"): + raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") coloredlogs.install(fmt="%(funcName)20s: %(message)s") - memory_monitor_type = None - if args.provider in ["cuda", "tensorrt"]: - memory_monitor_type = CudaMemoryMonitor - elif args.provider == "rocm": - memory_monitor_type = RocmMemoryMonitor + memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda" start_memory = measure_gpu_memory(memory_monitor_type, None) print("GPU memory used before loading models:", start_memory) @@ -951,89 +1405,157 @@ def main(): sd_model = SD_MODELS[args.version] provider = PROVIDERS[args.provider] if args.engine == "onnxruntime" and args.provider == "tensorrt": - result = run_ort_trt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, - args.enable_cuda_graph, - ) + if "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.tuning: + print( + "Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.".format( + "static input shape" if args.enable_cuda_graph else "dynamic batch size" + ) + ) + result = run_ort_trt( + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + enable_cuda_graph=args.enable_cuda_graph, + ) + else: + print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_static( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: - print("Pipeline is not specified. Trying export and optimize onnx models...") + print( + "Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.".format( + "static" if args.enable_cuda_graph else "dynamic" + ) + ) result = export_and_run_ort( - sd_model, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.enable_cuda_graph, + version=args.version, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + enable_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( args.pipeline ), "--pipeline should be specified for the directory of ONNX models" - - if args.version in ["2.1"]: - # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model - # This shall be done before the first inference run. - os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" - + print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}") result = run_ort( - sd_model, - args.pipeline, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.tuning, + model_name=sd_model, + directory=args.pipeline, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + tuning=args.tuning, + ) + elif args.engine == "tensorrt" and "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "tensorrt": - result = run_tensorrt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, + print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static( + work_dir=args.work_dir, + version=args.version, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) else: + print( + f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}." + ) result = run_torch( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.enable_torch_compile, - args.use_xformers, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + enable_torch_compile=args.enable_torch_compile, + use_xformers=args.use_xformers, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, ) print(result) @@ -1068,8 +1590,9 @@ def main(): if __name__ == "__main__": + import traceback + try: main() - except Exception as e: - tb = sys.exc_info() - print(e.with_traceback(tb[2])) + except Exception: + traceback.print_exception(*sys.exc_info()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py new file mode 100644 index 0000000000000..f6e00063a6391 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_txt2img import Txt2ImgPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + + args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + pipeline_info = PipelineInfo(args.version) + pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + pipeline.load_resources(image_height, image_width, batch_size) + + def run_inference(warmup=False): + return pipeline.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="images", + ) + + if not args.disable_cuda_graph: + # inference once to get cuda graph + _image, _latency = run_inference(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + _image, _latency = run_inference(warmup=True) + + print("[I] Running StableDiffusion pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + _image, _latency = run_inference(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py new file mode 100644 index 0000000000000..c3a2e4e293cc8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_img2img_xl import Img2ImgXLPipeline +from pipeline_txt2img_xl import Txt2ImgXLPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if args.build_dynamic_shape or image_height > 512 or image_width > 512: + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner) + base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + + if args.enable_refiner: + refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + refiner.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + refiner.load_resources(image_height, image_width, batch_size) + else: + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), base.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(enable_refiner: bool, warmup=False): + images, time_base = base.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="latents" if enable_refiner else "images", + ) + + if enable_refiner: + images, time_refiner = refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + ) + return images, time_base + time_refiner + else: + return images, time_base + + if not args.disable_cuda_graph: + # inference once to get cuda graph + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Running StableDiffusion XL pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + base.teardown() + + if args.enable_refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time)) + print("|------------|--------------|") + refiner.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py new file mode 100644 index 0000000000000..5fdafc463f4e2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import argparse + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths + + +class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + +def parse_arguments(is_xl: bool, description: str): + parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--engine", + type=str, + default="ORT_TRT", + choices=["ORT_TRT", "TRT"], + help="Backend engine. Default is OnnxRuntime CUDA execution provider.", + ) + + supported_versions = PipelineInfo.supported_versions(is_xl) + parser.add_argument( + "--version", + type=str, + default=supported_versions[-1] if is_xl else "1.5", + choices=supported_versions, + help="Version of Stable Diffusion" + (" XL." if is_xl else "."), + ) + + parser.add_argument( + "--height", + type=int, + default=1024 if is_xl else 512, + help="Height of image to generate (must be multiple of 8).", + ) + parser.add_argument( + "--width", type=int, default=1024 if is_xl else 512, help="Height of image to generate (must be multiple of 8)." + ) + + parser.add_argument( + "--scheduler", + type=str, + default="DDIM", + choices=["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process", + ) + + parser.add_argument( + "--work-dir", + default=".", + help="Root Directory to store torch or ONNX models, built engines and output images etc.", + ) + + parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + + parser.add_argument( + "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." + ) + parser.add_argument( + "--repeat-prompt", + type=int, + default=1, + choices=[1, 2, 4, 8, 16], + help="Number of times to repeat the prompt (batch size multiplier).", + ) + + parser.add_argument( + "--denoising-steps", + type=int, + default=30 if is_xl else 50, + help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."), + ) + + parser.add_argument( + "--guidance", + type=float, + default=5.0 if is_xl else 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=17, + choices=range(14, 18), + help="Select ONNX opset version to target for exported models.", + ) + parser.add_argument( + "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." + ) + parser.add_argument( + "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models. Default is pytorch_model.", + ) + parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") + + # Engine build options. + parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") + parser.add_argument( + "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." + ) + parser.add_argument( + "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." + ) + + # Inference related options + parser.add_argument( + "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance." + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") + parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + + # TensorRT only options + group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") + group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") + group.add_argument( + "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." + ) + group.add_argument( + "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." + ) + group.add_argument( + "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." + ) + + # Pipeline options + if is_xl: + parser.add_argument( + "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines." + ) + + args = parser.parse_args() + + # Validate image dimensions + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + ) + + if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: + print("[I] CUDA Graph is disabled since dynamic input shape is configured.") + args.disable_cuda_graph = True + + print(args) + + return args + + +def repeat_prompt(args): + if not isinstance(args.prompt, list): + raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") + prompt = args.prompt * args.repeat_prompt + + if not isinstance(args.negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" + ) + if len(args.negative_prompt) == 1: + negative_prompt = args.negative_prompt * len(prompt) + else: + negative_prompt = args.negative_prompt + + return prompt, negative_prompt + + +def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + args.work_dir, pipeline_info, engine_type + ) + + # Initialize demo + pipeline = pipeline_class( + pipeline_info, + scheduler=args.scheduler, + output_dir=output_dir, + hf_token=args.hf_token, + verbose=False, + nvtx_profile=args.nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=not args.disable_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + if engine_type == EngineType.ORT_TRT: + # Build TensorRT EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_image_shape=not args.build_dynamic_shape, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.TRT: + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_batch_size=batch_size, + opt_image_height=args.height, + opt_image_width=args.height, + force_export=args.force_onnx_export, + force_optimize=args.force_onnx_optimize, + force_build=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_shape=not args.build_dynamic_shape, + enable_refit=args.build_enable_refit, + enable_preview=args.build_preview_features, + enable_all_tactics=args.build_all_tactics, + timing_cache=timing_cache, + onnx_refit_dir=args.onnx_refit_dir, + ) + + return pipeline diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py new file mode 100644 index 0000000000000..951cd66005f4c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -0,0 +1,858 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion, +# which has the following license: +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import tempfile +from typing import List, Optional + +import onnx +import onnx_graphsurgeon as gs +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from onnx import GraphProto, ModelProto, shape_inference +from ort_optimizer import OrtStableDiffusionOptimizer +from polygraphy.backend.onnx.loader import fold_constants +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from onnxruntime.transformers.onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class TrtOptimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: + with tempfile.TemporaryDirectory() as temp_dir: + input_onnx_path = os.path.join(temp_dir, "model.onnx") + onnx.save_model( + onnx_graph, + input_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx") + onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path) + onnx_graph = onnx.load(output_onnx_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class PipelineInfo: + def __init__( + self, version: str, is_inpaint: bool = False, is_sd_xl_refiner: bool = False, use_vae_in_xl_base=False + ): + self.version = version + self._is_inpaint = is_inpaint + self._is_sd_xl_refiner = is_sd_xl_refiner + self._use_vae_in_xl_base = use_vae_in_xl_base + + if is_sd_xl_refiner: + assert self.is_sd_xl() + + def is_inpaint(self) -> bool: + return self._is_inpaint + + def is_sd_xl(self) -> bool: + return "xl" in self.version + + def is_sd_xl_base(self) -> bool: + return self.is_sd_xl() and not self._is_sd_xl_refiner + + def is_sd_xl_refiner(self) -> bool: + return self.is_sd_xl() and self._is_sd_xl_refiner + + def use_safetensors(self) -> bool: + return self.is_sd_xl() + + def stages(self) -> List[str]: + if self.is_sd_xl_base(): + return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae_in_xl_base else []) + + if self.is_sd_xl_refiner(): + return ["clip2", "unetxl", "vae"] + + return ["clip", "unet", "vae"] + + def vae_scaling_factor(self) -> float: + return 0.13025 if self.is_sd_xl() else 0.18215 + + @staticmethod + def supported_versions(is_xl: bool): + return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] + + def name(self) -> str: + if self.version == "1.4": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "CompVis/stable-diffusion-v1-4" + elif self.version == "1.5": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "runwayml/stable-diffusion-v1-5" + elif self.version == "2.0-base": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2-base" + elif self.version == "2.0": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2" + elif self.version == "2.1": + return "stabilityai/stable-diffusion-2-1" + elif self.version == "2.1-base": + return "stabilityai/stable-diffusion-2-1-base" + elif self.version == "xl-1.0": + if self.is_sd_xl_refiner(): + return "stabilityai/stable-diffusion-xl-refiner-1.0" + else: + return "stabilityai/stable-diffusion-xl-base-1.0" + + raise ValueError(f"Incorrect version {self.version}") + + def short_name(self) -> str: + return self.name().split("/")[-1].replace("stable-diffusion", "sd") + + def clip_embedding_dim(self): + # TODO: can we read from config instead + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 768 + else: + raise ValueError(f"Invalid version {self.version}") + + def clipwithproj_embedding_dim(self): + if self.version in ("xl-1.0"): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + def unet_embedding_dim(self): + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 2048 + elif self.version in ("xl-1.0") and self.is_sd_xl_refiner(): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + +class BaseModel: + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16: bool = False, + max_batch_size: int = 16, + embedding_dim: int = 768, + text_maxlen: int = 77, + ): + self.name = self.__class__.__name__ + + self.pipeline_info = pipeline_info + + self.model = model + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_ort_optimizer(self): + model_name_to_model_type = { + "CLIP": "clip", + "UNet": "unet", + "VAE": "vae", + "UNetXL": "unet", + "CLIPWithProj": "clip", + } + model_type = model_name_to_model_type[self.name] + return OrtStableDiffusionOptimizer(model_type) + + def get_model(self): + return self.model + + def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder, **kwargs): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + + if not os.path.exists(model_dir): + model = model_class.from_pretrained( + self.pipeline_info.name(), + subfolder=subfolder, + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + **kwargs, + ).to(self.device) + model.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + + model = model_class.from_pretrained(model_dir).to(self.device) + return model + + def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): + pass + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT EP""" + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT""" + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + optimizer = self.get_ort_optimizer() + optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + + if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF: + onnx.save_model( + onnx_opt_graph, + optimized_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + embedding_dim: int = 0, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(), + ) + self.output_hidden_state = pipeline_info.is_sd_xl() + + # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip. + # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + self.clip_skip = clip_skip + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model. + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) + + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path): + graph: GraphProto = model.graph + hidden_layers = -1 + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + name = graph.node[i].output[j] + if "layers" in name: + hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers) + + assert self.clip_skip >= 0 and self.clip_skip < hidden_layers + + node_output_name = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers - 1 - self.clip_skip) + + # search the name in outputs of all node + found = False + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + if graph.node[i].output[j] == node_output_name: + found = True + break + if found: + break + if not found: + raise RuntimeError("Failed to find hidden_states graph output in clip") + + # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output. + graph_output_name = "hidden_states" + cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name]) + cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)]) + + hidden_state = graph.output.add() + hidden_state.CopyFrom( + onnx.helper.make_tensor_value_info( + graph_output_name, + graph.output[0].type.tensor_type.elem_type, + ["B", self.text_maxlen, self.embedding_dim], + ) + ) + + onnx_model = OnnxModel(model) + onnx_model.add_node(cast_node) + onnx_model.save_model_to_file(optimized_onnx_path) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + if self.output_hidden_state: + self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder"): + return self.from_pretrained(CLIPTextModel, framework_model_dir, hf_token, subfolder) + + +class CLIPWithProj(CLIP): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size=16, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.clipwithproj_embedding_dim(), + clip_skip=clip_skip, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder_2"): + return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, hf_token, subfolder) + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + +class UNet(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +class UNetXL(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + time_dim=6, + ): + super().__init__( + pipeline_info, + model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.time_dim = time_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "time_ids": [ + (2 * min_batch, self.time_dim), + (2 * batch_size, self.time_dim), + (2 * max_batch, self.time_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (1,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + "text_embeds": (2 * batch_size, 1280), + "time_ids": (2 * batch_size, self.time_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + { + "added_cond_kwargs": { + "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + } + }, + ) + + +# VAE Decoder +class VAE(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + if not os.path.exists(model_dir): + vae = AutoencoderKL.from_pretrained( + self.pipeline_info.name(), + subfolder="vae", + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + ).to(self.device) + vae.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) + + vae.forward = vae.decode + return vae + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + + +def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): + tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) + + if not os.path.exists(tokenizer_dir): + model = CLIPTokenizer.from_pretrained( + pipeline_info.name(), + subfolder=subfolder, + use_safetensors=pipeline_info.is_sd_xl(), + use_auth_token=hf_token, + ) + model.save_pretrained(tokenizer_dir) + else: + print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}") + model = CLIPTokenizer.from_pretrained(tokenizer_dir) + return model + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, vae_encoder): + super().__init__() + self.vae_encoder = vae_encoder + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="vae_encoder"): + vae = self.from_pretrained(AutoencoderKL, framework_model_dir, hf_token, subfolder) + return TorchVAEEncoder(vae) + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py new file mode 100644 index 0000000000000..13c450a517eba --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -0,0 +1,721 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from utilities.py of TensorRT demo diffusion, which has the following license: +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import torch + + +class DDIMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + clip_sample: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 1, + prediction_type: str = "epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.steps_offset = steps_offset + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.prediction_type = prediction_type + self.device = device + + def configure(self): + variance = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + variance[idx] = self._get_variance(timestep, prev_timestep) + self.variance = torch.from_numpy(variance).to(self.device) + + timesteps = self.timesteps.long().cpu() + self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) + + def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(self.device) + self.timesteps += self.steps_offset + + def step( + self, + model_output, + sample, + idx, + timestep, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: torch.FloatTensor = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + prev_idx = idx + 1 + alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1) + variance = self.variance[idx] + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = variance ** (0.5) * eta * variance_noise + + prev_sample = prev_sample + variance + + return prev_sample + + def add_noise(self, init_latents, noise, idx, latent_timestep): + sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise + + return noisy_latents + + +class EulerAncestralDiscreteScheduler: + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + device="cuda", + steps_offset=0, + prediction_type="epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + self.device = device + self.num_train_timesteps = num_train_timesteps + self.steps_offset = steps_offset + self.prediction_type = prediction_type + + def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=self.device) + self.timesteps = torch.from_numpy(timesteps).to(device=self.device) + + def configure(self): + dts = np.zeros(self.num_inference_steps, dtype=np.float32) + sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + dts[idx] = dt + sigmas_up[idx] = sigma_up + + self.dts = torch.from_numpy(dts).to(self.device) + self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device) + + def step( + self, + model_output, + sample, + idx, + timestep, + generator=None, + ): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_up = self.sigmas_up[idx] + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = self.dts[idx] + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device) + + prev_sample = prev_sample + noise * sigma_up + + return prev_sample + + def add_noise(self, original_samples, noise, idx, timestep=None): + step_index = (self.timesteps == timestep).nonzero().item() + noisy_samples = original_samples + noise * self.sigmas[step_index] + return noisy_samples + + +class UniPCMultistepScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: Optional[List[int]] = None, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector if disable_corrector else [] + self.last_sample = None + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.prediction_type = prediction_type + self.thresholding = thresholding + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + self.solver_type = solver_type + self.lower_order_final = lower_order_final + + def set_timesteps(self, num_inference_steps: int): + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(self.device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + if self.predict_x0: + if self.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.prediction_type == "sample": + x0_pred = model_output + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.prediction_type == "epsilon": + return model_output + elif self.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1]) + else: + d1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * b_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * b_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + # this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + # x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) + else: + d1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_c = torch.linalg.solve(r, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + # this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.lower_order_final: + this_order = min(self.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return prev_sample + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=self.device, dtype=original_samples.dtype) + timesteps = timesteps.to(self.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py new file mode 100644 index 0000000000000..64c3c5bc80ecb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -0,0 +1,181 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +from enum import Enum + +import torch +from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL + + +class EngineType(Enum): + ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider + ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider + TRT = 2 # TensorRT + TORCH = 3 # PyTorch + + +def get_engine_type(name: str) -> EngineType: + name_to_type = { + "ORT_CUDA": EngineType.ORT_CUDA, + "ORT_TRT": EngineType.ORT_TRT, + "TRT": EngineType.TRT, + "TORCH": EngineType.TORCH, + } + return name_to_type[name] + + +class EngineBuilder: + def __init__( + self, + engine_type: EngineType, + pipeline_info: PipelineInfo, + device="cuda", + max_batch_size=16, + hf_token=None, + use_cuda_graph=False, + ): + """ + Initializes the Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + device (str | torch.device): + device to run engine + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + self.engine_type = engine_type + self.pipeline_info = pipeline_info + self.max_batch_size = max_batch_size + self.hf_token = hf_token + self.use_cuda_graph = use_cuda_graph + self.device = torch.device(device) + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.stages = pipeline_info.stages() + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.models = {} + self.engines = {} + self.torch_models = {} + + def teardown(self): + for engine in self.engines.values(): + del engine + self.engines = {} + + def get_cached_model_name(self, model_name): + if self.pipeline_info.is_inpaint(): + model_name += "_inpaint" + return model_name + + def get_onnx_path(self, model_name, onnx_dir, opt=True): + engine_name = self.engine_type.name.lower() + onnx_model_dir = os.path.join( + onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def get_engine_path(self, engine_dir, model_name, profile_id): + return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id) + + def load_models(self, framework_model_dir: str): + # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + delattr(torch.nn.functional, "scaled_dot_product_attention") + + # For TRT or ORT_TRT, we will export fp16 torch model for UNet. + # For ORT_CUDA, we export fp32 model first, then optimize to fp16. + export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + + if "clip" in self.stages: + self.models["clip"] = CLIP( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "clip2" in self.stages: + self.models["clip2"] = CLIPWithProj( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "unet" in self.stages: + self.models["unet"] = UNet( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) + + if "unetxl" in self.stages: + self.models["unetxl"] = UNetXL( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=4, + time_dim=(5 if self.pipeline_info.is_sd_xl_refiner() else 6), + ) + + # VAE Decoder + if "vae" in self.stages: + self.models["vae"] = VAE( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + ) + + if self.vae_torch_fallback: + self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir, self.hf_token) + + def load_resources(self, image_height, image_width, batch_size): + # Allocate buffers for I/O bindings + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + def vae_decode(self, latents): + if self.vae_torch_fallback: + latents = latents.to(dtype=torch.float32) + self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) + images = self.torch_models["vae"](latents)["sample"] + else: + images = self.run_engine("vae", {"latent": latents})["images"] + + return images + + +def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): + root_dir = work_dir or "." + short_name = pipeline_info.short_name() + + # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since + # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model. + onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") + engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") + output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") + framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") + + return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py new file mode 100644 index 0000000000000..253cdcc45bf2e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -0,0 +1,263 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +class OrtTensorrtEngine(CudaSession): + def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + self.engine_path = engine_path + self.ort_trt_provider_options = self.get_tensorrt_provider_options( + input_profile, + workspace_size, + fp16, + device_id, + enable_cuda_graph, + ) + + session_options = ort.SessionOptions() + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + print("creating TRT EP session for ", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + ("TensorrtExecutionProvider", self.ort_trt_provider_options), + ], + ) + print("created TRT EP session for ", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + trt_ep_options = { + "device_id": device_id, + "trt_fp16_enable": fp16, + "trt_engine_cache_enable": True, + "trt_timing_cache_enable": True, + "trt_detailed_build_log": True, + "trt_engine_cache_path": self.engine_path, + } + + if enable_cuda_graph: + trt_ep_options["trt_cuda_graph_enable"] = True + + if workspace_size > 0: + trt_ep_options["trt_max_workspace_size"] = workspace_size + + if input_profile: + min_shapes = [] + max_shapes = [] + opt_shapes = [] + for name, profile in input_profile.items(): + assert isinstance(profile, list) and len(profile) == 3 + min_shape = profile[0] + opt_shape = profile[1] + max_shape = profile[2] + assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) + + min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) + opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) + max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) + + trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) + trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) + trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) + + logger.info("trt_ep_options=%s", trt_ep_options) + + return trt_ep_options + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtTensorrtEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def has_engine_file(self, engine_path): + if os.path.isdir(engine_path): + children = os.scandir(engine_path) + for entry in children: + if entry.is_file() and entry.name.endswith(".engine"): + return True + return False + + def get_work_space_size(self, model_name, max_workspace_size): + gibibyte = 2**30 + workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size + if workspace_size == 0: + _, free_mem, _ = cudart.cudaMemGetInfo() + # The following logic are adopted from TensorRT demo diffusion. + if free_mem > 6 * gibibyte: + workspace_size = free_mem - 4 * gibibyte + return workspace_size + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_image_shape=True, + max_workspace_size=0, + device_id=0, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if not self.has_engine_file(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info(f"Exporting model: {onnx_path}") + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Optimize onnx + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimizing model: %s", onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if not self.has_engine_file(engine_path): + logger.info( + "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", + model_name, + onnx_opt_path, + engine_path, + ) + else: + logger.info("Reuse cached TensorRT engine in directory %s", engine_path) + + input_profile = model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_image_shape=static_image_shape, + ) + + engine = OrtTensorrtEngine( + engine_path, + device_id, + onnx_opt_path, + fp16=True, + input_profile=input_profile, + workspace_size=self.get_work_space_size(model_name, max_workspace_size), + enable_cuda_graph=self.use_cuda_graph, + ) + + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py new file mode 100644 index 0000000000000..4a924abfb8600 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import gc +import os +import pathlib +from collections import OrderedDict + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from trt_utilities import TRT_LOGGER + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, +} + + +def _cuda_assert(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class TensorrtEngine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None + + def __del__(self): + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, onnx_path, onnx_refit_path): + def convert_int64(arr): + if len(arr.shape) == 0: + return np.int32(arr) + return arr + + def add_to_map(refit_dict, name, values): + if name in refit_dict: + assert refit_dict[name] is None + if values.dtype == np.int64: + values = convert_int64(values) + refit_dict[name] = values + + print(f"Refitting TensorRT engine with {onnx_refit_path} weights") + refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes + + # Construct mapping from weight names in refit model -> original model + name_map = {} + for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): + refit_node = refit_nodes[n] + assert node.op == refit_node.op + # Constant nodes in ONNX do not have inputs but have a constant output + if node.op == "Constant": + name_map[refit_node.outputs[0].name] = node.outputs[0].name + # Handle scale and bias weights + elif node.op == "Conv": + if node.inputs[1].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" + if node.inputs[2].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" + # For all other nodes: find node inputs that are initializers (gs.Constant) + else: + for i, inp in enumerate(node.inputs): + if inp.__class__ == gs.Constant: + name_map[refit_node.inputs[i].name] = inp.name + + def map_name(name): + if name in name_map: + return name_map[name] + return name + + # Construct refit dictionary + refit_dict = {} + refitter = trt.Refitter(self.engine, TRT_LOGGER) + all_weights = refitter.get_all() + for layer_name, role in zip(all_weights[0], all_weights[1]): + # for specialized roles, use a unique name in the map: + if role == trt.WeightsRole.KERNEL: + name = layer_name + "_TRTKERNEL" + elif role == trt.WeightsRole.BIAS: + name = layer_name + "_TRTBIAS" + else: + name = layer_name + + assert name not in refit_dict, "Found duplicate layer: " + name + refit_dict[name] = None + + for n in refit_nodes: + # Constant nodes in ONNX do not have inputs but have a constant output + if n.op == "Constant": + name = map_name(n.outputs[0].name) + print(f"Add Constant {name}\n") + add_to_map(refit_dict, name, n.outputs[0].values) + + # Handle scale and bias weights + elif n.op == "Conv": + if n.inputs[1].__class__ == gs.Constant: + name = map_name(n.name + "_TRTKERNEL") + add_to_map(refit_dict, name, n.inputs[1].values) + + if n.inputs[2].__class__ == gs.Constant: + name = map_name(n.name + "_TRTBIAS") + add_to_map(refit_dict, name, n.inputs[2].values) + + # For all other nodes: find node inputs that are initializers (AKA gs.Constant) + else: + for inp in n.inputs: + name = map_name(inp.name) + if inp.__class__ == gs.Constant: + add_to_map(refit_dict, name, inp.values) + + for layer_name, weights_role in zip(all_weights[0], all_weights[1]): + if weights_role == trt.WeightsRole.KERNEL: + custom_name = layer_name + "_TRTKERNEL" + elif weights_role == trt.WeightsRole.BIAS: + custom_name = layer_name + "_TRTBIAS" + else: + custom_name = layer_name + + # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model + if layer_name.startswith("onnx::Trilu"): + continue + + if refit_dict[custom_name] is not None: + refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) + else: + print(f"[W] No refit weights for layer: {layer_name}") + + if not refitter.refit_cuda_engine(): + print("Failed to refit!") + exit(0) + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + update_output_names=None, + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + if update_output_names: + print(f"Updating network outputs to {update_output_names}") + network = ModifyNetworkOutputs(network, update_output_names) + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + ), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + _cuda_assert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + _cuda_assert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) + ) + self.context.execute_async_v3(stream) + self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream)) + + from cuda import nvrtc + + result, major, minor = nvrtc.nvrtcVersion() + assert result == nvrtc.nvrtcResult(0) + if major < 12: + self.cuda_graph_instance = _cuda_assert( + cudart.cudaGraphInstantiate(self.graph, b"", 0) + ) # cuda < 12 + else: + self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12 + else: + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TensorrtEngineBuilder(EngineBuilder): + """ + Helper class to hide the detail of TensorRT Engine from pipeline. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + self.stream = None + self.shared_device_memory = None + + def load_resources(self, image_height, image_width, batch_size): + super().load_resources(image_height, image_width, batch_size) + + self.stream = _cuda_assert(cudart.cudaStreamCreate()) + + def teardown(self): + super().teardown() + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def load_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + onnx_refit_dir=None, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to write the TensorRT engines. + framework_model_dir (str): + Directory to write the framework model ckpt. + onnx_dir (str): + Directory to write the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + force_export (bool): + Force re-exporting the ONNX models. + force_optimize (bool): + Force re-optimizing the ONNX models. + force_build (bool): + Force re-building the TensorRT engine. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_preview (bool): + Enable TensorRT preview features. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to accelerate build or None + onnx_refit_dir (str): + Directory containing refit ONNX models. + """ + # Create directory + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.load_models(framework_model_dir) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if force_export or force_build or not os.path.exists(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if force_export or not os.path.exists(onnx_opt_path): + if force_export or not os.path.exists(onnx_path): + print(f"Exporting model: {onnx_path}") + model = obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = obj.get_sample_input(1, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=obj.get_input_names(), + output_names=obj.get_output_names(), + dynamic_axes=obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + print(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_optimize or not os.path.exists(onnx_opt_path): + print(f"Generating optimizing model: {onnx_opt_path}") + obj.optimize_trt(onnx_path, onnx_opt_path) + else: + print(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + engine = TensorrtEngine(engine_path) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if force_build or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch, + static_shape, + ), + enable_refit=enable_refit, + enable_preview=enable_preview, + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=None, + ) + self.engines[model_name] = engine + + # Load TensorRT engines + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].load() + if onnx_refit_dir: + onnx_refit_path = self.get_onnx_path(model_name, onnx_refit_dir, opt=True) + if os.path.exists(onnx_refit_path): + self.engines[model_name].refit(onnx_opt_path, onnx_refit_path) + + def max_device_memory(self): + max_device_memory = 0 + for _model_name, engine in self.engines.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engines.values(): + engine.activate(reuse_device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py deleted file mode 100644 index 0f7688a3df9f6..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py +++ /dev/null @@ -1,368 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Models used in Stable diffusion. -""" -import logging - -import onnx -import onnx_graphsurgeon as gs -import torch -from onnx import shape_inference -from ort_optimizer import OrtStableDiffusionOptimizer -from polygraphy.backend.onnx.loader import fold_constants - -logger = logging.getLogger(__name__) - - -class TrtOptimizer: - def __init__(self, onnx_graph): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self): - self.graph.cleanup().toposort() - - def get_optimized_onnx_graph(self): - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - - def infer_shapes(self): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - - -class BaseModel: - def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" - self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type) - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT EP""" - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" - - if self.name != "CLIP": - if static_image_shape: - profile_id += f"_h_{image_height}_w_{image_width}" - else: - profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" - - return profile_id - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT""" - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): - self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - -class CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="CLIP", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_image_shape - ) - return { - "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - -class UNet(BaseModel): - def __init__( - self, - model, - device="cuda", - fp16=False, # used by TRT - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4, - ): - super().__init__( - model=model, - name="UNet", - device=device, - fp16=fp16, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="VAE Decoder", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py index 6134fa7bddcf4..37785869a355b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -43,16 +43,14 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import Engines +from diffusion_models import CLIP, VAE, PipelineInfo, UNet +from ort_utils import Engines, StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer logger = logging.getLogger(__name__) -class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. @@ -70,11 +68,12 @@ def __init__( requires_safety_checker: bool = True, # ONNX export parameters onnx_opset: int = 14, - onnx_dir: str = "raw_onnx", + onnx_dir: str = "onnx_ort", # Onnxruntime execution provider parameters - engine_dir: str = "onnxruntime_optimized_onnx", + engine_dir: str = "ORT_CUDA", force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: PipelineInfo = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -96,51 +95,38 @@ def __init__( self.fp16 = False - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size + self.pipeline_info = pipeline_info - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) + def load_models(self): + assert self.pipeline_info.clip_embedding_dim() == self.text_encoder.config.hidden_size - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=self.fp16, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), - ) + stages = self.pipeline_info.stages() + if "clip" in stages: + self.models["clip"] = CLIP( + self.pipeline_info, + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) + if "unet" in stages: + self.models["unet"] = UNet( + self.pipeline_info, + self.unet, + device=self.torch_device, + fp16=False, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, + if "vae" in stages: + self.models["vae"] = VAE( + self.pipeline_info, + self.vae, + device=self.torch_device, + max_batch_size=self.max_batch_size, ) - ) def to( self, @@ -156,7 +142,7 @@ def to( # load models self.fp16 = torch_dtype == torch.float16 - self.__load_models() + self.load_models() # build engines self.engines.build( @@ -180,88 +166,6 @@ def to( return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = ( - self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone() - ) - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) - - # Predict the noise residual - noise_pred = self.engines.get_engine("unet").infer( - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = self.engines.get_engine("vae").infer({"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - def __allocate_buffers(self, image_height, image_width, batch_size): # Allocate output tensors for I/O bindings for model_name, obj in self.models.items(): @@ -337,7 +241,7 @@ def __call__( with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines.get_engine("clip"), prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet_in_channels @@ -352,30 +256,37 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent( + self.engines.get_engine("unet"), latents, text_embeddings, timestep_fp16=self.fp16 + ) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines.get_engine("vae"), latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" +def example(): + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name_or_path, scheduler=scheduler, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name_or_path) + pipe.set_cached_folder(model_name_or_path, resume_download=True, local_files_only=True) pipe = pipe.to("cuda", torch_dtype=torch.float16) prompt = "photorealistic new zealand hills" image = pipe(prompt).images[0] - image.save("ort_trt_txt2img_new_zealand_hills.png") + image.save("ort_cuda_txt2img_new_zealand_hills.png") + + +if __name__ == "__main__": + example() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py index 6f3c215f36318..c663e37c7ea7d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -32,13 +32,11 @@ pip install onnxruntime-gpu """ -import gc +import logging import os -import shutil from typing import List, Optional, Union import torch -from cuda import cudart from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( StableDiffusionPipeline, @@ -46,224 +44,15 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE, logging -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import OrtCudaSession +from diffusion_models import PipelineInfo +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from ort_utils import StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import onnxruntime as ort +logger = logging.getLogger(__name__) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -class Engine(OrtCudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): - self.engine_path = engine_path - self.ort_trt_provider_options = self.get_tensorrt_provider_options( - input_profile, - workspace_size, - fp16, - device_id, - enable_cuda_graph, - ) - - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - ort_session = ort.InferenceSession( - onnx_path, - sess_options, - providers=[ - ("TensorrtExecutionProvider", self.ort_trt_provider_options), - ], - ) - - device = torch.device("cuda", device_id) - super().__init__(ort_session, device, enable_cuda_graph) - - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): - trt_ep_options = { - "device_id": device_id, - "trt_fp16_enable": fp16, - "trt_engine_cache_enable": True, - "trt_timing_cache_enable": True, - "trt_detailed_build_log": True, - "trt_engine_cache_path": self.engine_path, - } - - if enable_cuda_graph: - trt_ep_options["trt_cuda_graph_enable"] = True - - if workspace_size > 0: - trt_ep_options["trt_max_workspace_size"] = workspace_size - - if input_profile: - min_shapes = [] - max_shapes = [] - opt_shapes = [] - for name, profile in input_profile.items(): - assert isinstance(profile, list) and len(profile) == 3 - min_shape = profile[0] - opt_shape = profile[1] - max_shape = profile[2] - assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) - - min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) - opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) - max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) - - trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) - trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) - trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) - - logger.info("trt_ep_options=%s", trt_ep_options) - - return trt_ep_options - - -def get_onnx_path(model_name, onnx_dir, opt=True): - return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") - - -def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id) - - -def has_engine_file(engine_path): - if os.path.isdir(engine_path): - children = os.scandir(engine_path) - for entry in children: - if entry.is_file() and entry.name.endswith(".engine"): - return True - return False - - -def get_work_space_size(model_name, max_workspace_size): - gibibyte = 2**30 - workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size - if workspace_size == 0: - _, free_mem, _ = cudart.cudaMemGetInfo() - # The following logic are adopted from TensorRT demo diffusion. - if free_mem > 6 * gibibyte: - workspace_size = free_mem - 4 * gibibyte - return workspace_size - - -def build_engines( - models, - engine_dir, - onnx_dir, - onnx_opset, - opt_image_height, - opt_image_width, - opt_batch_size=1, - force_engine_rebuild=False, - static_batch=False, - static_image_shape=True, - max_workspace_size=0, - device_id=0, - enable_cuda_graph=False, -): - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - engine_path = get_engine_path(engine_dir, model_name, profile_id) - if not has_engine_file(engine_path): - onnx_path = get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - if not os.path.exists(onnx_opt_path): - if not os.path.exists(onnx_path): - logger.info(f"Exporting model: {onnx_path}") - model = model_obj.get_model() - with torch.inference_mode(), torch.autocast("cuda"): - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) - torch.onnx.export( - model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - else: - logger.info("Found cached model: %s", onnx_path) - - # Optimize onnx - if not os.path.exists(onnx_opt_path): - logger.info("Generating optimizing model: %s", onnx_opt_path) - model_obj.optimize_trt(onnx_path, onnx_opt_path) - else: - logger.info("Found cached optimized model: %s", onnx_opt_path) - - built_engines = {} - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - - engine_path = get_engine_path(engine_dir, model_name, profile_id) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - - if not has_engine_file(engine_path): - logger.info( - "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", - model_name, - onnx_opt_path, - engine_path, - ) - else: - logger.info("Reuse cached TensorRT engine in directory %s", engine_path) - - input_profile = model_obj.get_input_profile( - opt_batch_size, - opt_image_height, - opt_image_width, - static_batch=static_batch, - static_image_shape=static_image_shape, - ) - - engine = Engine( - engine_path, - device_id, - onnx_opt_path, - fp16=True, - input_profile=input_profile, - workspace_size=get_work_space_size(model_name, max_workspace_size), - enable_cuda_graph=enable_cuda_graph, - ) - - built_engines[model_name] = engine - - return built_engines - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. @@ -285,11 +74,12 @@ def __init__( max_batch_size: int = 16, # ONNX export parameters onnx_opset: int = 17, - onnx_dir: str = "onnx", + onnx_dir: str = "onnx_trt", # TensorRT engine build parameters - engine_dir: str = "onnxruntime_tensorrt_engine", + engine_dir: str = "ORT_TRT", # use short name here to avoid path exceeds 260 chars in Windows. force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: Optional[PipelineInfo] = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -299,16 +89,14 @@ def __init__( self.image_height = image_height self.image_width = image_width - self.inpaint = False self.onnx_opset = onnx_opset self.onnx_dir = onnx_dir self.engine_dir = engine_dir self.force_engine_rebuild = force_engine_rebuild - self.enable_cuda_graph = enable_cuda_graph - # Although cuda graph requires static input shape, engine built with dyamic batch gets better performance in T4. + # Although cuda graph requires static input shape, engine built with dynamic batch gets better performance in T4. # Use static batch could reduce GPU memory footprint. - self.build_static_batch = False + self.build_static_batch = enable_cuda_graph # TODO: support dynamic image shape. self.build_dynamic_shape = False @@ -318,54 +106,13 @@ def __init__( if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: self.max_batch_size = 4 - self.models = {} # loaded in __load_models() self.engines = {} # loaded in build_engines() - - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size - - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) - - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=True, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), + self.engine_builder = OrtTensorrtEngineBuilder( + pipeline_info, max_batch_size=max_batch_size, use_cuda_graph=enable_cuda_graph ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) - - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) + self.pipeline_info = pipeline_info + self.stages = pipeline_info.stages() def to( self, @@ -381,11 +128,9 @@ def to( self.torch_device = self._execution_device logger.info(f"Running inference on device: {self.torch_device}") - self.__load_models() - - self.engines = build_engines( - self.models, + self.engines = self.engine_builder.build_engines( self.engine_dir, + None, self.onnx_dir, self.onnx_opset, opt_image_height=self.image_height, @@ -394,96 +139,10 @@ def to( static_batch=self.build_static_batch, static_image_shape=not self.build_dynamic_shape, device_id=self.torch_device.index, - enable_cuda_graph=self.enable_cuda_graph, ) return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # Predict the noise residual - timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - - noise_pred = run_engine( - self.engines["unet"], - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = run_engine(self.engines["vae"], {"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - - def __allocate_buffers(self, image_height, image_width, batch_size): - # Allocate output tensors for I/O bindings - for model_name, obj in self.models.items(): - self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) - @torch.no_grad() def __call__( self, @@ -547,11 +206,11 @@ def __call__( f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - self.__allocate_buffers(self.image_height, self.image_width, batch_size) + self.engine_builder.load_resources(self.image_height, self.image_width, batch_size) with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines["clip"], prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet.config.in_channels @@ -566,10 +225,10 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent(self.engines["unet"], latents, text_embeddings) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines["vae"], latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) @@ -577,8 +236,8 @@ def __call__( if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" - + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( @@ -589,6 +248,7 @@ def __call__( image_height=512, image_width=512, max_batch_size=4, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index aef60a534608a..ffcfd6d9fd7e0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -13,7 +13,7 @@ # python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16 # # Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support -# for the fused opeartors. The users could disable the operator fusion manually to workaround. +# for the fused operators. The users could disable the operator fusion manually to workaround. import argparse import logging @@ -49,7 +49,6 @@ def has_external_data(onnx_model_path): def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, - overwrite: bool, use_external_data_format: Optional[bool], float16: bool, force_fp32_ops: List[str], @@ -61,7 +60,6 @@ def _optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. - overwrite (bool): Overwrite files if exists. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision force_fp32_ops(List[str]): operators that are forced to run in float32. @@ -235,7 +233,7 @@ def optimize_stable_diffusion_pipeline( args, ): if os.path.exists(output_dir): - if args.overwrite: + if overwrite: shutil.rmtree(output_dir, ignore_errors=True) else: raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") @@ -249,7 +247,6 @@ def optimize_stable_diffusion_pipeline( _optimize_sd_pipeline( source_dir, target_dir, - overwrite, use_external_data_format, float16, args.force_fp32_ops, @@ -321,7 +318,7 @@ def parse_arguments(argv: Optional[List[str]] = None): required=False, action="store_true", help="Onnx model larger than 2GB need to use external data format. " - "If specifed, save each onnx model to two files: one for onnx graph, another for weights. " + "If specified, save each onnx model to two files: one for onnx graph, another for weights. " "If not specified, use same format as original model by default. ", ) parser.set_defaults(use_external_data_format=None) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 0824c8f07d6e2..2c4b8e8a1639e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,6 +12,7 @@ from pathlib import Path import onnx +from optimize_pipeline import has_external_data from onnxruntime.transformers.fusion_options import FusionOptions from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel @@ -32,21 +33,25 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model): + def optimize_by_ort(self, onnx_model, use_external_data_format=False): # Use this step to see the final graph that executed by Onnx Runtime. with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + str(tmp_model_path), + use_gpu=True, + optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, + external_data_filename="optimized.onnx_data", ) model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep_io_types=False, keep_outputs=None): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") fusion_options = FusionOptions(self.model_type) @@ -54,6 +59,8 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False + use_external_data_format = has_external_data(input_fp32_onnx_path) + m = optimize_model( input_fp32_onnx_path, model_type=self.model_type, @@ -64,21 +71,24 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): use_gpu=True, ) - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + if keep_outputs is None and self.model_type == "clip": + # remove the pooler_output, and only keep the first output. + keep_outputs = ["text_embeddings"] + + if keep_outputs: + m.prune_graph(outputs=keep_outputs) if float16: logger.info("Convert to float16 ...") m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], + keep_io_types=keep_io_types, ) - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 + # Note that ORT < 1.16 could not save model larger than 2GB. if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 7192e4ad5584f..5c2145845e757 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -7,122 +7,24 @@ import logging import os import shutil -from collections import OrderedDict -from typing import Any, Dict +from typing import Union import torch import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import TypeHelper +from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) -class OrtCudaSession: - """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" - - def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): - self.ort_session = ort_session - self.input_names = [input.name for input in self.ort_session.get_inputs()] - self.output_names = [output.name for output in self.ort_session.get_outputs()] - self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) - self.io_binding = self.ort_session.io_binding() - self.enable_cuda_graph = enable_cuda_graph - - self.input_tensors = OrderedDict() - self.output_tensors = OrderedDict() - self.device = device - - def __del__(self): - del self.input_tensors - del self.output_tensors - del self.io_binding - del self.ort_session - - def allocate_buffers(self, shape_dict: Dict[str, tuple]): - """Allocate tensors for I/O Binding""" - if self.enable_cuda_graph: - for name, shape in shape_dict.items(): - if name in self.input_names: - # Reuse allocated buffer when the shape is same - if name in self.input_tensors: - if tuple(self.input_tensors[name].shape) == tuple(shape): - continue - raise RuntimeError("Expect static input shape for cuda graph") - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.input_tensors[name] = tensor - - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - for name, shape in shape_dict.items(): - if name in self.output_names: - # Reuse allocated buffer when the shape is same - if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): - continue - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.output_tensors[name] = tensor - - self.io_binding.bind_output( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - def infer(self, feed_dict): - """Bind input tensors and run inference""" - for name, tensor in feed_dict.items(): - assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() - if name in self.input_names: - if self.enable_cuda_graph: - assert self.input_tensors[name].nelement() == tensor.nelement() - assert tensor.device.type == "cuda" - # Update input tensor inplace since cuda graph requires input and output has fixed memory address. - from cuda import cudart - - cudart.cudaMemcpy( - self.input_tensors[name].data_ptr(), - tensor.data_ptr(), - tensor.element_size() * tensor.nelement(), - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, - ) - else: - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - TypeHelper.torch_type_to_numpy_type(tensor.dtype), - [1] if len(tensor.shape) == 0 else list(tensor.shape), - tensor.data_ptr(), - ) - - self.ort_session.run_with_iobinding(self.io_binding) - - return self.output_tensors - - -class Engine(OrtCudaSession): +# ----------------------------------------------------------------------------------------------------- +# Utilities for CUDA EP +# ----------------------------------------------------------------------------------------------------- +class Engine(CudaSession): def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): self.engine_path = engine_path self.provider = provider - self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) device = torch.device("cuda", device_id) ort_session = ort.InferenceSession( @@ -135,13 +37,6 @@ def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_g super().__init__(ort_session, device, enable_cuda_graph) - def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: - return { - "device_id": device_id, - "arena_extend_strategy": "kSameAsRequested", - "enable_cuda_graph": enable_cuda_graph, - } - class Engines: def __init__(self, provider, onnx_opset: int = 14): @@ -197,9 +92,16 @@ def build( model = model_obj.get_model().to(model_obj.device) with torch.inference_mode(): inputs = model_obj.get_sample_input(1, 512, 512) + fp32_inputs = tuple( + [ + (tensor.to(torch.float32) if tensor.dtype == torch.float16 else tensor) + for tensor in inputs + ] + ) + torch.onnx.export( model, - inputs, + fp32_inputs, onnx_path, export_params=True, opset_version=self.onnx_opset, @@ -224,3 +126,125 @@ def build( def get_engine(self, model_name): return self.engines[model_name] + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +# ----------------------------------------------------------------------------------------------------- +# Utilities for both CUDA and TensorRT EP +# ----------------------------------------------------------------------------------------------------- + + +class StableDiffusionPipelineMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def encode_prompt(self, clip_engine, prompt, negative_prompt): + """ + Encodes the prompt into text encoder hidden states. + """ + + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(clip_engine, {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(clip_engine, {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def denoise_latent( + self, + unet_engine, + latents, + text_embeddings, + timesteps=None, + mask=None, + masked_image_latents=None, + timestep_fp16=False, + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.to(torch.float16) if timestep_fp16 else timestep.to(torch.float32) + + noise_pred = run_engine( + unet_engine, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def decode_latent(self, vae_engine, latents): + images = run_engine(vae_engine, {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def set_cached_folder(self, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + from diffusers.utils import DIFFUSERS_CACHE + from huggingface_hub import snapshot_download + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + self.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py new file mode 100644 index 0000000000000..0e2aeb6174666 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -0,0 +1,232 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Img2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Img2Img XL pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Img2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_refiner() + + super().__init__(pipeline_info, *args, **kwargs) + + self.requires_aesthetics_score = True + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0).to(device=self.device) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + strength = 0.3 + aesthetic_score = 6.0 + negative_aesthetic_score = 2.5 + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + batch_size = len(prompt) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize timesteps + timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) + + # CLIP text encoder 2 + text_embeddings, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=text_embeddings.dtype, + ) + + add_time_ids = add_time_ids.repeat(batch_size, 1) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # Pre-process input image + init_image = self.preprocess_images(batch_size, (init_image,))[0] + + # VAE encode init image + if init_image.shape[1] == 4: + init_latents = init_image + else: + init_latents = self.encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float32, generator=self.generator) + latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=t_start, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, + ) + + with torch.inference_mode(): + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Refiner Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "img2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + init_image (tuple[torch.Tensor]): + Image from base pipeline. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000..a053c9d5d0835 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,429 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import os +import pathlib +import random + +import nvtx +import torch +from cuda import cudart +from diffusion_models import PipelineInfo, get_tokenizer +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from engine_builder import EngineType +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from engine_builder_tensorrt import TensorrtEngineBuilder + + +class StableDiffusionPipeline: + """ + Stable Diffusion pipeline using TensorRT. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + scheduler="DDIM", + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir="pytorch_model", + engine_type: EngineType = EngineType.ORT_TRT, + ): + """ + Initializes the Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + scheduler (str): + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + device (str): + PyTorch device to run inference. Default: 'cuda' + output_dir (str): + Output directory for log files and image artifacts + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + framework_model_dir (str): + cache directory for framework checkpoints + engine_type (EngineType) + backend engine type like ORT_TRT or TRT + """ + + self.pipeline_info = pipeline_info + self.version = pipeline_info.version + + self.vae_scaling_factor = pipeline_info.vae_scaling_factor() + + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device) + else: + raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + + self.stages = pipeline_info.stages() + + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.use_cuda_graph = use_cuda_graph + + self.tokenizer = None + self.tokenizer2 = None + + self.generator = None + self.denoising_steps = None + + # backend engine + self.engine_type = engine_type + if engine_type == EngineType.TRT: + self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_TRT: + self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + else: + raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") + + # Load text tokenizer + if not self.pipeline_info.is_sd_xl_refiner(): + self.tokenizer = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" + ) + + if self.pipeline_info.is_sd_xl(): + self.tokenizer2 = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" + ) + + # Create CUDA events + self.events = {} + for stage in ["clip", "denoise", "vae", "vae_encoder"]: + for marker in ["start", "stop"]: + self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] + + def is_backend_tensorrt(self): + return self.engine_type == EngineType.TRT + + def set_denoising_steps(self, denoising_steps: int): + if self.denoising_steps != denoising_steps: + assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(denoising_steps) + self.scheduler.configure() + self.denoising_steps = denoising_steps + + def load_resources(self, image_height, image_width, batch_size): + # If engine is built with static input shape, call this only once after engine build. + # Otherwise, it need be called before every inference run. + self.backend.load_resources(image_height, image_width, batch_size) + + def set_random_seed(self, seed): + # Initialize noise generator. Usually, it is done before a batch of inference. + self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e) + + if self.backend: + self.backend.teardown() + + def run_engine(self, model_name, feed_dict): + return self.backend.run_engine(model_name, feed_dict) + + def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width): + latents_dtype = torch.float32 # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + return timesteps, t_start + + def preprocess_images(self, batch_size, images=()): + if self.nvtx_profile: + nvtx_image_preprocess = nvtx.start_range(message="image_preprocess", color="pink") + init_images = [] + for i in images: + image = i.to(self.device).float() + if image.shape[0] != batch_size: + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + if self.nvtx_profile: + nvtx.end_range(nvtx_image_preprocess) + return tuple(init_images) + + def encode_prompt( + self, prompt, negative_prompt, encoder="clip", tokenizer=None, pooled_outputs=False, output_hidden_states=False + ): + if tokenizer is None: + tokenizer = self.tokenizer + + if self.nvtx_profile: + nvtx_clip = nvtx.start_range(message="clip", color="green") + cudart.cudaEventRecord(self.events["clip-start"], 0) + + # Tokenize prompt + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + hidden_states = outputs["hidden_states"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + if pooled_outputs: + pooled_output = text_embeddings + + if output_hidden_states: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + + cudart.cudaEventRecord(self.events["clip-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_clip) + + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + def denoise_latent( + self, + latents, + text_embeddings, + denoiser="unet", + timesteps=None, + step_offset=0, + mask=None, + masked_image_latents=None, + guidance=7.5, + image_guidance=1.5, + add_kwargs=None, + ): + assert guidance > 1.0, "Guidance has to be > 1.0" + assert image_guidance > 1.0, "Image guidance has to be > 1.0" + + cudart.cudaEventRecord(self.events["denoise-start"], 0) + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + if self.nvtx_profile: + nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") + + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, step_offset + step_index, timestep + ) + + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_scale) + + # Predict the noise residual + if self.nvtx_profile: + nvtx_unet = nvtx.start_range(message="unet", color="blue") + + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = latent_model_input + timestep_inp = timestep_float + embeddings_inp = text_embeddings + + params = {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp} + if add_kwargs: + params.update(add_kwargs) + + noise_pred = self.run_engine(denoiser, params)["latent"] + + if self.nvtx_profile: + nvtx.end_range(nvtx_unet) + + if self.nvtx_profile: + nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + + if type(self.scheduler) == UniPCMultistepScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + else: + latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) + + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_step) + + latents = 1.0 / self.vae_scaling_factor * latents + cudart.cudaEventRecord(self.events["denoise-stop"], 0) + return latents + + def encode_image(self, init_image): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae_encoder", color="red") + cudart.cudaEventRecord(self.events["vae_encoder-start"], 0) + init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] + cudart.cudaEventRecord(self.events["vae_encoder-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + + init_latents = self.vae_scaling_factor * init_latents + return init_latents + + def decode_latent(self, latents): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae", color="red") + cudart.cudaEventRecord(self.events["vae-start"], 0) + images = self.backend.vae_decode(latents) + cudart.cudaEventRecord(self.events["vae-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + return images + + def print_summary(self, tic, toc, batch_size, vae_enc=False): + print("|------------|--------------|") + print("| {:^10} | {:^12} |".format("Module", "Latency")) + print("|------------|--------------|") + if vae_enc: + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "UNet x " + str(self.denoising_steps), + cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + ) + ) + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("|------------|--------------|") + print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + + @staticmethod + def to_pil_image(images): + images = ( + ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + ) + from PIL import Image + + return [Image.fromarray(images[i]) for i in range(images.shape[0])] + + def save_images(self, images, pipeline, prompt): + image_name_prefix = ( + pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" + ) + + images = self.to_pil_image(images) + random_session_id = str(random.randint(1000, 9999)) + for i, image in enumerate(images): + image_path = os.path.join( + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + ) + print(f"Saving image {i+1} / {len(images)} to: {image_path}") + image.save(image_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py new file mode 100644 index 0000000000000..82f73e8b3cc61 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, **kwargs): + """ + Initializes the Txt2Img Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + super().__init__(pipeline_info, **kwargs) + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=50, + guidance=7.5, + seed=None, + warmup=False, + return_type="latents", + ): + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) + + # VAE decode latent + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "txt2img", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=7.5, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + type of return. The value can be "latents" or "images". + """ + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py new file mode 100644 index 0000000000000..d8f00ed619354 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img XL pipeline. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Txt2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_base() + + super().__init__(pipeline_info, *args, **kwargs) + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt( + prompt, negative_prompt, encoder="clip", tokenizer=self.tokenizer, output_hidden_states=True + ) + # CLIP text encoder 2 + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Merged text embeddings + text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype + ) + add_time_ids = add_time_ids.repeat(batch_size, 1) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # UNet denoiser + latents = self.denoise_latent( + latents, text_embeddings, denoiser="unetxl", guidance=guidance, add_kwargs=add_kwargs + ) + + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Base Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + if return_type == "images": + self.save_images(images, "txt2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt index b942749f8dcd2..2a3caf4c2392b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt @@ -1,8 +1,14 @@ -r requirements.txt -onnxruntime-gpu>=1.14 +onnxruntime-gpu>=1.16 py3nvml>=0.2.7 + # cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 +cuda-python==11.8.0 +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# To export onnx, please install PyTorch 2.10 like +# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt index 567f39c0119e6..5b59c64ab7470 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt @@ -1,18 +1,2 @@ -diffusers>=0.16.0 -transformers>=4.26.0 -numpy>=1.24.1 -accelerate -onnx>=1.13.0 -coloredlogs -packaging -protobuf -psutil -sympy +-r requirements-cuda.txt tensorrt>=8.6.1 -onnxruntime-gpu>=1.15.1 -py3nvml -# cuda-python version shall be compatible with CUDA version of torch and onnxruntime-gpu -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py new file mode 100644 index 0000000000000..d03a9f9f55372 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import tensorrt as trt + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def init_trt_plugins(): + # Register TensorRT plugins + trt.init_libnvinfer_plugins(TRT_LOGGER, "") diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 60be2d84b2bc8..e9c24ed3eb09b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -610,7 +610,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled. - Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to eanble + Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable symbolic shape inference. If your model is not optimized, you can also use model path to call convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to avoid the 2GB limit. @@ -832,7 +832,7 @@ def get_first_output(node): # Keep track of nodes to keep. The key is first output of node, and the value is the node. output_to_node = {} - # Start from graph outputs, and find parent nodes recurisvely, and add nodes to the output_to_node dictionary. + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. dq = deque() for output in keep_outputs: if output in output_name_to_node: @@ -1161,7 +1161,7 @@ def has_same_value( signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. Returns: - bool: True when two intializers has same value. + bool: True when two initializers has same value. """ sig1 = ( signature_cache1[tensor1.name]