add demo_txt2img
tianleiwu committed Oct 3, 2023
1 parent ef94b1b commit 7fe1ec7
Showing 8 changed files with 408 additions and 241 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]):
        if name in self.input_names:
            if self.enable_cuda_graph:
                assert self.input_tensors[name].nelement() == tensor.nelement()
                assert self.input_tensors[name].dtype == tensor.dtype
                assert tensor.device.type == "cuda"
                # Please install the cuda-python package with a version corresponding to the CUDA version on your machine.
                from cuda import cudart
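The hunk is truncated here, but the assertions above set up an in-place device-to-device copy: a captured CUDA graph replays against fixed device addresses, so each new input must be copied into the pre-allocated tensor rather than rebound. A minimal sketch of such a copy with cuda-python (function and tensor names are illustrative, not the helper's actual code):

```python
import torch
from cuda import cudart  # pip install cuda-python, matching your CUDA toolkit version


def copy_into_graph_input(dst: torch.Tensor, src: torch.Tensor) -> None:
    # The destination is the buffer captured by the CUDA graph; it must match
    # the incoming tensor in element count and dtype, per the assertions above.
    assert dst.nelement() == src.nelement() and dst.dtype == src.dtype
    (err,) = cudart.cudaMemcpy(
        dst.data_ptr(),
        src.data_ptr(),
        src.element_size() * src.nelement(),
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
    )
    assert err == cudart.cudaError_t.cudaSuccess
```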
96 changes: 96 additions & 0 deletions onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py
@@ -0,0 +1,96 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------

import coloredlogs
from cuda import cudart
from demo_utils import init_pipeline, parse_arguments, repeat_prompt
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_type
from pipeline_txt2img import Txt2ImgPipeline

if __name__ == "__main__":
    coloredlogs.install(fmt="%(funcName)20s: %(message)s")

    args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo")
    prompt, negative_prompt = repeat_prompt(args)

    image_height = args.height
    image_width = args.width

    # Register TensorRT plugins
    engine_type = get_engine_type(args.engine)
    if engine_type == EngineType.TRT:
        from trt_utilities import init_trt_plugins

        init_trt_plugins()

    max_batch_size = 16
    if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512):
        max_batch_size = 4

    batch_size = len(prompt)
    if batch_size > max_batch_size:
        raise ValueError(
            f"Batch size {batch_size} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4."
        )

    pipeline_info = PipelineInfo(args.version)
    pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size)

    if engine_type == EngineType.TRT:
        # Allocate one shared device-memory arena sized to the engine's workspace requirement.
        max_device_memory = pipeline.backend.max_device_memory()
        _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
        pipeline.backend.activate_engines(shared_device_memory)
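    # cudart.cudaMalloc returns an (error, pointer) tuple and the error code is
    # discarded above; a stricter variant would check it, for example:
    #   err, shared_device_memory = cudart.cudaMalloc(max_device_memory)
    #   assert err == cudart.cudaError_t.cudaSuccess, f"cudaMalloc failed: {err}"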

    pipeline.load_resources(image_height, image_width, batch_size)

    def run_inference(warmup=False):
        images, time_base = pipeline.run(
            prompt,
            negative_prompt,
            image_height,
            image_width,
            warmup=warmup,
            denoising_steps=args.denoising_steps,
            guidance=args.guidance,
            seed=args.seed,
            return_type="images",
        )

        return images, time_base

    if not args.disable_cuda_graph:
        # Run inference once to capture the CUDA graph.
        images, _ = run_inference(warmup=True)

    print("[I] Warming up ..")
    for _ in range(args.num_warmup_runs):
        images, _ = run_inference(warmup=True)

    print("[I] Running StableDiffusion pipeline")
    if args.nvtx_profile:
        cudart.cudaProfilerStart()
    images, pipeline_time = run_inference(warmup=False)
    if args.nvtx_profile:
        cudart.cudaProfilerStop()

    pipeline.teardown()
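Once the prerequisites are installed, the demo is driven from the command line; a minimal invocation might look like the following (the prompt text is illustrative, and the flags mirror the parse_arguments options shown in the next file's removed parser, which moved into demo_utils):

```
python demo_txt2img.py "astronaut riding a horse on mars"
```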
onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py
@@ -21,161 +21,23 @@
# --------------------------------------------------------------------------

import argparse
import coloredlogs

import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_paths, get_engine_type
from pipeline_img2img_xl import Img2ImgXLPipeline
from pipeline_txt2img_xl import Txt2ImgXLPipeline


def parse_arguments():
    parser = argparse.ArgumentParser(description="Options for Stable Diffusion XL Demo", conflict_handler="resolve")
    parser.add_argument(
        "--engine",
        type=str,
        default="ORT_TRT",
        choices=["ORT_TRT", "TRT"],
        help="Backend engine. Default is ORT_TRT, which means OnnxRuntime TensorRT execution provider.",
    )

    parser.add_argument(
        "--version", type=str, default="xl-1.0", choices=["xl-1.0"], help="Version of Stable Diffusion XL"
    )
    parser.add_argument(
        "--height", type=int, default=1024, help="Height of image to generate (must be multiple of 8). Default is 1024."
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="Width of image to generate (must be multiple of 8). Default is 1024."
    )

    parser.add_argument(
        "--scheduler",
        type=str,
        default="DDIM",
        choices=["DDIM", "EulerA", "UniPC"],
        help="Scheduler for diffusion process",
    )

    parser.add_argument(
        "--work-dir",
        default="",
        help="Root directory to store torch or ONNX models, built engines, output images, etc.",
    )

    parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation")

    parser.add_argument(
        "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation."
    )
    parser.add_argument(
        "--repeat-prompt",
        type=int,
        default=1,
        choices=[1, 2, 4, 8, 16],
        help="Number of times to repeat the prompt (batch size multiplier). Default is 1.",
    )
    parser.add_argument(
        "--denoising-steps",
        type=int,
        default=30,
        help="Number of denoising steps in each of base and refiner. Default is 30.",
    )
    parser.add_argument(
        "--guidance",
        type=float,
        default=5.0,
        help="Higher guidance scale encourages generating images closely linked to the text prompt.",
    )

    # ONNX export
    parser.add_argument(
        "--onnx-opset",
        type=int,
        default=17,
        choices=range(14, 18),
        help="Select ONNX opset version to target for exported models. Default is 17.",
    )
    parser.add_argument(
        "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models"
    )
    parser.add_argument(
        "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models"
    )

    # Framework model ckpt
    parser.add_argument("--framework-model-dir", default="pytorch_model", help="Directory for HF saved models")
    parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints")

    # Engine build options.
    parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine")
    parser.add_argument(
        "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size."
    )
    parser.add_argument(
        "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes."
    )

    # Inference related options
    parser.add_argument(
        "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance"
    )
    parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling")
    parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results")
    parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.")

    # TensorRT only options
    group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
    group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from")
    group.add_argument(
        "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build."
    )
    group.add_argument(
        "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features."
    )
    group.add_argument(
        "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
    )

    # Pipeline options
    parser.add_argument(
        "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipeline."
    )

    return parser.parse_args()

from demo_utils import init_pipeline, parse_arguments, repeat_prompt

if __name__ == "__main__":
    args = parse_arguments()

    if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
        print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
        args.disable_cuda_graph = True
    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
    prompt, negative_prompt = repeat_prompt(args)

    print(args)

    # Process prompt
    if not isinstance(args.prompt, list):
        raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
    prompt = args.prompt * args.repeat_prompt

    if not isinstance(args.negative_prompt, list):
        raise ValueError(
            f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
        )
    if len(args.negative_prompt) == 1:
        negative_prompt = args.negative_prompt * len(prompt)
    else:
        negative_prompt = args.negative_prompt
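The prompt handling removed here is what repeat_prompt in demo_utils now encapsulates; a sketch consistent with the deleted logic (the actual helper may differ):

```python
def repeat_prompt(args):
    # Repeat the positional prompt(s) to build the batch.
    prompt = args.prompt * args.repeat_prompt
    # Broadcast a single negative prompt across the whole batch.
    if len(args.negative_prompt) == 1:
        negative_prompt = args.negative_prompt * len(prompt)
    else:
        negative_prompt = args.negative_prompt
    return prompt, negative_prompt
```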

    # Validate image dimensions
    image_height = args.height
    image_width = args.width
    if image_height % 8 != 0 or image_width % 8 != 0:
        raise ValueError(
            f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}."
        )

    # Register TensorRT plugins
    engine_type = get_engine_type(args.engine)
@@ -194,73 +56,12 @@ def parse_arguments():
f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
)

    def init_pipeline(pipeline_class, pipeline_info, engine_type):
        onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
            args.work_dir, pipeline_info, engine_type
        )

        # Initialize demo
        pipeline = pipeline_class(
            pipeline_info,
            scheduler=args.scheduler,
            output_dir=output_dir,
            hf_token=args.hf_token,
            verbose=False,
            nvtx_profile=args.nvtx_profile,
            max_batch_size=max_batch_size,
            use_cuda_graph=not args.disable_cuda_graph,
            framework_model_dir=framework_model_dir,
            engine_type=engine_type,
        )

        # Load TensorRT engines and pytorch modules
        if engine_type == EngineType.ORT_TRT:
            pipeline.backend.build_engines(
                engine_dir,
                framework_model_dir,
                onnx_dir,
                args.onnx_opset,
                opt_image_height=image_height,
                opt_image_width=image_width,
                opt_batch_size=len(prompt),
                # force_export=args.force_onnx_export,
                # force_optimize=args.force_onnx_optimize,
                force_engine_rebuild=args.force_engine_build,
                static_batch=not args.build_dynamic_batch,
                static_image_shape=not args.build_dynamic_shape,
                max_workspace_size=0,
                device_id=torch.cuda.current_device(),
            )
        elif engine_type == EngineType.TRT:
            # Load TensorRT engines and pytorch modules
            pipeline.backend.load_engines(
                engine_dir,
                framework_model_dir,
                onnx_dir,
                args.onnx_opset,
                opt_batch_size=len(prompt),
                opt_image_height=image_height,
                opt_image_width=image_width,
                force_export=args.force_onnx_export,
                force_optimize=args.force_onnx_optimize,
                force_build=args.force_engine_build,
                static_batch=not args.build_dynamic_batch,
                static_shape=not args.build_dynamic_shape,
                enable_refit=args.build_enable_refit,
                enable_preview=args.build_preview_features,
                enable_all_tactics=args.build_all_tactics,
                timing_cache=timing_cache,
                onnx_refit_dir=args.onnx_refit_dir,
            )

        return pipeline

    base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner)
    base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type)
    base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)

    if args.enable_refiner:
        refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True)
        refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type)
        refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)

    if engine_type == EngineType.TRT:
        max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
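With the refactored parser now living in demo_utils, a typical XL invocation might look like the following (the prompt text is illustrative; --enable-refiner runs both base and refiner pipelines):

```
python demo_txt2img_xl.py --enable-refiner "starry night over a calm lake, detailed oil painting"
```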