SDXL Latent Consistency Model (LCM) optimization (#18526)
Add support for the LCM model
(https://huggingface.co/latent-consistency/lcm-sdxl) in the SDXL demo.

Since the LCM model does not need classifier-free guidance, there is no
need for a negative prompt. Its input and output shapes also differ
from the original SDXL model: the batch dimension is not doubled.
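To illustrate the shape difference, here is a minimal sketch (not the pipeline's actual code; the helper name is made up): classifier-free guidance runs the UNet on a doubled batch that concatenates the unconditional and conditional halves, while the LCM path keeps the user-requested batch size.

```python
import torch

def prepare_unet_latents(latents: torch.Tensor, use_cfg: bool) -> torch.Tensor:
    """Illustrative helper: classifier-free guidance doubles the batch so one UNet
    call covers both the negative-prompt and positive-prompt halves; LCM skips
    guidance, so the batch dimension stays as-is."""
    if use_cfg:
        return torch.cat([latents, latents], dim=0)  # (2*B, 4, H/8, W/8)
    return latents                                   # (B, 4, H/8, W/8)

latents = torch.randn(1, 4, 128, 128)  # batch 1, 1024x1024 image -> 128x128 latent
print(prepare_unet_latents(latents, use_cfg=True).shape)   # torch.Size([2, 4, 128, 128])  SDXL base
print(prepare_unet_latents(latents, use_cfg=False).shape)  # torch.Size([1, 4, 128, 128])  LCM
```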

We also save metadata to the generated images, and update the image
filename to include the scheduler and number of steps.
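A minimal sketch of what saving metadata to an image can look like, assuming Pillow; the helper name and filename pattern are illustrative rather than the demo's exact implementation:

```python
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_image_with_metadata(image: Image.Image, metadata: dict, scheduler: str, steps: int) -> str:
    # Store each metadata entry as a PNG text chunk so the generation settings travel with the file.
    info = PngInfo()
    for key, value in metadata.items():
        info.add_text(key, str(value))
    # Include scheduler and step count in the filename, as described above.
    filename = f"txt2img_xl_{scheduler}_{steps}steps.png"
    image.save(filename, pnginfo=info)
    return filename

# Example: save_image_with_metadata(img, {"scheduler": "LCM", "denoising_steps": 4}, "LCM", 4)
```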

#### Latency (milliseconds) of generating 1024x1024 images on an A100-SXM4-80GB GPU

Engines are built with static input shapes, and CUDA graph is enabled.
With dynamic input shapes, latency could be higher.
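For reference, ONNX Runtime exposes CUDA graph capture as a CUDA execution provider option; a minimal sketch (the model path is illustrative):

```python
import onnxruntime as ort

providers = [
    ("CUDAExecutionProvider", {"enable_cuda_graph": "1"}),  # capture/replay kernels as a CUDA graph
    "CPUExecutionProvider",
]
session = ort.InferenceSession("unet_xl.onnx", providers=providers)
# CUDA graph replay assumes input/output shapes (and, with I/O binding, buffer addresses)
# stay fixed across runs, which is why these benchmark engines are built with static shapes.
```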

Batch Size | Pipeline | Steps | ORT_CUDA | ORT_TRT | TRT 8.6
-- | -- | -- | -- | -- | --
1 | LCM SDXL | 4 | 275 | 249 | 258
1 | LCM SDXL | 8 | 460 | 423 | 430
1 | SDXL Base | 30 | 2566 | 2535 | 2569
4 | LCM SDXL | 4 | 925 | 887 | 1032
4 | LCM SDXL | 8 | 1539 | 1493 | 1662
4 | SDXL Base | 30 | 9227 | 9408 | 9678
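For context, at batch size 1 the 4-step LCM pipeline is roughly 9x faster than the 30-step SDXL Base pipeline on ORT_CUDA (2566 / 275 ≈ 9.3), and roughly 10x faster at batch size 4 (9227 / 925 ≈ 10.0).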
tianleiwu authored Nov 22, 2023
1 parent d455b0f commit 62da3b1
Showing 12 changed files with 570 additions and 148 deletions.
@@ -83,6 +83,9 @@ For example:

If you do not provide a prompt, the script will generate images of different sizes for a list of demonstration prompts.

#### Generate an image with SDXL LCM guided by a text prompt
```python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"```

## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum

If you are able to run the above demo with Docker, you can reuse that Docker environment, skip the following setup, and fast forward to [Export ONNX pipeline](#export-onnx-pipeline).
@@ -22,7 +22,7 @@

import coloredlogs
from cuda import cudart
from demo_utils import init_pipeline, parse_arguments, repeat_prompt
from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_type
from pipeline_txt2img import Txt2ImgPipeline
@@ -104,17 +104,25 @@ def run_inference(warmup=False):

if not args.disable_cuda_graph:
# inference once to get cuda graph
_image, _latency = run_inference(warmup=True)
_, _ = run_inference(warmup=True)

print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_image, _latency = run_inference(warmup=True)
_, _ = run_inference(warmup=True)

print("[I] Running StableDiffusion pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
_image, _latency = run_inference(warmup=False)
images, perf_data = run_inference(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()

metadata = get_metadata(args, False)
metadata.update(pipeline.metadata())
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
pipeline.save_images(images, prompt, negative_prompt, metadata)

pipeline.teardown()
@@ -22,7 +22,7 @@

import coloredlogs
from cuda import cudart
from demo_utils import init_pipeline, parse_arguments, repeat_prompt
from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_type
from pipeline_img2img_xl import Img2ImgXLPipeline
@@ -54,7 +54,11 @@ def load_pipelines(args, batch_size):

# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(
args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size
args.version,
use_vae=args.disable_refiner,
min_image_size=min_image_size,
max_image_size=max_image_size,
use_lcm=args.lcm,
)

# Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to
@@ -118,7 +122,7 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False
refiner.load_resources(image_height, image_width, batch_size)

def run_base_and_refiner(warmup=False):
images, time_base = base.run(
images, base_perf = base.run(
prompt,
negative_prompt,
image_height,
@@ -130,24 +134,31 @@ def run_base_and_refiner(warmup=False):
return_type="latent" if refiner else "image",
)
if refiner is None:
return images, time_base
return images, base_perf

# Use same seed in base and refiner.
seed = base.get_current_seed()

images, time_refiner = refiner.run(
images, refiner_perf = refiner.run(
prompt,
negative_prompt,
images,
image_height,
image_width,
warmup=warmup,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
denoising_steps=args.refiner_steps,
strength=args.strength,
guidance=args.refiner_guidance,
seed=seed,
)

return images, time_base + time_refiner
perf_data = None
if base_perf and refiner_perf:
perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]}
perf_data.update({"base." + key: val for key, val in base_perf.items()})
perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()})

return images, perf_data

if not args.disable_cuda_graph:
# inference once to get cuda graph
@@ -164,13 +175,24 @@ def run_base_and_refiner(warmup=False):
print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
_, latency = run_base_and_refiner(warmup=False)
images, perf_data = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()

print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
print("|------------|--------------|")
if refiner:
print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", perf_data["latency"]))
print("|------------|--------------|")

metadata = get_metadata(args, True)
metadata.update({"base." + key: val for key, val in base.metadata().items()})
if refiner:
metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()})
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
(refiner or base).save_images(images, prompt, negative_prompt, metadata)


def run_demo(args):
@@ -189,6 +211,8 @@ def run_dynamic_shape_demo(args):
"""Run demo of generating images with different settings with ORT CUDA provider."""
args.engine = "ORT_CUDA"
args.disable_cuda_graph = True
if args.lcm:
args.disable_refiner = True
base, refiner = load_pipelines(args, 1)

prompts = [
@@ -198,22 +222,31 @@
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
"beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
"An astronaut riding a rainbow unicorn, cinematic, dramatic",
"close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm",
]

# batch size, height, width, scheduler, steps, prompt, seed
# batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength
configs = [
(1, 832, 1216, "UniPC", 8, prompts[0], None),
(1, 1024, 1024, "DDIM", 24, prompts[1], None),
(1, 1216, 832, "UniPC", 16, prompts[2], None),
(1, 1344, 768, "DDIM", 24, prompts[3], None),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906),
(1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3),
(1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3),
(1, 1216, 832, "UniPC", 16, prompts[2], None, 5.0, "UniPC", 10, 0.3),
(1, 1344, 768, "DDIM", 24, prompts[3], None, 5.0, "UniPC", 20, 0.3),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3),
]

# When testing LCM, the refiner is disabled, so the refiner settings are not used.
if args.lcm:
configs = [
(1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3),
(1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3),
]

# Warm up each combination of (batch size, height, width) once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 1
for batch_size, height, width, _, _, _, _ in configs:
for batch_size, height, width, _, _, _, _, _, _, _, _ in configs:
args.batch_size = batch_size
args.height = height
args.width = width
@@ -223,20 +256,33 @@

# Run pipeline on a list of prompts.
args.num_warmup_runs = 0
for batch_size, height, width, scheduler, steps, example_prompt, seed in configs:
for (
batch_size,
height,
width,
scheduler,
steps,
example_prompt,
seed,
guidance,
refiner_scheduler,
refiner_steps,
strength,
) in configs:
args.prompt = [example_prompt]
args.batch_size = batch_size
args.height = height
args.width = width
args.scheduler = scheduler
args.denoising_steps = steps
args.seed = seed
args.guidance = guidance
args.refiner_scheduler = refiner_scheduler
args.refiner_steps = refiner_steps
args.strength = strength
base.set_scheduler(scheduler)
if refiner:
refiner.set_scheduler(scheduler)
print(
f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
)
refiner.set_scheduler(refiner_scheduler)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)

@@ -21,6 +21,7 @@
# --------------------------------------------------------------------------

import argparse
from typing import Any, Dict

import torch
from diffusion_models import PipelineInfo
@@ -68,8 +69,8 @@ def parse_arguments(is_xl: bool, description: str):
"--scheduler",
type=str,
default="DDIM",
choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"],
help="Scheduler for diffusion process",
choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC"],
help="Scheduler for diffusion process" + " of base" if is_xl else "",
)

parser.add_argument(
@@ -105,6 +106,42 @@
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
)

if is_xl:
parser.add_argument(
"--lcm",
action="store_true",
help="Use fine-tuned latent consistency model to replace the UNet in base.",
)

parser.add_argument(
"--refiner-scheduler",
type=str,
default="DDIM",
choices=["DDIM", "UniPC"],
help="Scheduler for diffusion process of refiner.",
)

parser.add_argument(
"--refiner-guidance",
type=float,
default=5.0,
help="Guidance scale used in refiner.",
)

parser.add_argument(
"--refiner-steps",
type=int,
default=30,
help="Number of denoising steps in refiner. Note that actual refiner steps is refiner_steps * strength.",
)

parser.add_argument(
"--strength",
type=float,
default=0.3,
help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
)

# ONNX export
parser.add_argument(
"--onnx-opset",
@@ -190,11 +227,52 @@ def parse_arguments(is_xl: bool, description: str):
if args.onnx_opset is None:
args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17

if is_xl:
if args.lcm:
if args.guidance > 1.0:
print("[I] Use --guidance=1.0 for base since LCM is used.")
args.guidance = 1.0
if args.scheduler != "LCM":
print("[I] Use --scheduler=LCM for base since LCM is used.")
args.scheduler = "LCM"
if args.denoising_steps > 16:
print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.")
args.denoising_steps = 8
assert args.strength > 0.0 and args.strength < 1.0

print(args)

return args


def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]:
metadata = {
"args.prompt": args.prompt,
"args.negative_prompt": args.negative_prompt,
"args.batch_size": args.batch_size,
"height": args.height,
"width": args.width,
"cuda_graph": not args.disable_cuda_graph,
"vae_slicing": args.enable_vae_slicing,
"engine": args.engine,
}

if is_xl and not args.disable_refiner:
metadata["base.scheduler"] = args.scheduler
metadata["base.denoising_steps"] = args.denoising_steps
metadata["base.guidance"] = args.guidance
metadata["refiner.strength"] = args.strength
metadata["refiner.scheduler"] = args.refiner_scheduler
metadata["refiner.denoising_steps"] = args.refiner_steps
metadata["refiner.guidance"] = args.refiner_guidance
else:
metadata["scheduler"] = args.scheduler
metadata["denoising_steps"] = args.denoising_steps
metadata["guidance"] = args.guidance

return metadata


def repeat_prompt(args):
if not isinstance(args.prompt, list):
raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
@@ -223,7 +301,7 @@ def init_pipeline(
# Initialize demo
pipeline = pipeline_class(
pipeline_info,
scheduler=args.scheduler,
scheduler=args.refiner_scheduler if pipeline_info.is_xl_refiner() else args.scheduler,
output_dir=output_dir,
hf_token=args.hf_token,
verbose=False,