Skip to content

Commit

Permalink
SDXL: Update demo with dynamic shape serving with CUDA EP (microsoft#…
Browse files Browse the repository at this point in the history
…18340)

Update the SDXL demo with dynamic shape serving with CUDA EP.
  • Loading branch information
tianleiwu authored and kleiti committed Mar 22, 2024
1 parent 5213f36 commit 13e83a8
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,7 @@
from pipeline_txt2img_xl import Txt2ImgXLPipeline


def run_demo():
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""

args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")

prompt, negative_prompt = repeat_prompt(args)

# Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf).
image_height = args.height
image_width = args.width

def load_pipelines(args, batch_size):
# Register TensorRT plugins
engine_type = get_engine_type(args.engine)
if engine_type == EngineType.TRT:
Expand All @@ -49,19 +39,18 @@ def run_demo():

max_batch_size = 16
if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and (
args.build_dynamic_shape or image_height > 512 or image_width > 512
args.build_dynamic_shape or args.height > 512 or args.width > 512
):
max_batch_size = 4

batch_size = len(prompt)
if batch_size > max_batch_size:
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")

# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(args.version, use_vae=False)
base_info = PipelineInfo(args.version, use_vae=False, min_image_size=640, max_image_size=1536)
base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)

refiner_info = PipelineInfo(args.version, is_refiner=True)
refiner_info = PipelineInfo(args.version, is_refiner=True, min_image_size=640, max_image_size=1536)
refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)

if engine_type == EngineType.TRT:
Expand All @@ -77,7 +66,13 @@ def run_demo():
enable_vae_slicing = True
if enable_vae_slicing:
refiner.backend.enable_vae_slicing()
return base, refiner


def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False):
image_height = args.height
image_width = args.width
batch_size = len(prompt)
base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)

Expand Down Expand Up @@ -112,25 +107,95 @@ def run_base_and_refiner(warmup=False):
# inference once to get cuda graph
_, _ = run_base_and_refiner(warmup=True)

print("[I] Warming up ..")
if args.num_warmup_runs > 0:
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_base_and_refiner(warmup=True)

if is_warm_up:
return

print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
_, latency = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()

base.teardown()

print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
print("|------------|--------------|")


def run_demo(args):
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""

prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
base, refiner = load_pipelines(args, batch_size)
run_pipelines(args, base, refiner, prompt, negative_prompt)
base.teardown()
refiner.teardown()


def run_dynamic_shape_demo(args):
"""Run demo of generating images with different size with list of prompts with ORT CUDA provider."""
args.engine = "ORT_CUDA"
args.scheduler = "UniPC"
args.denoising_steps = 8
args.disable_cuda_graph = True

batch_size = args.repeat_prompt
base, refiner = load_pipelines(args, batch_size)

image_sizes = [
(1024, 1024),
(1152, 896),
(896, 1152),
(1216, 832),
(832, 1216),
(1344, 768),
(768, 1344),
(1536, 640),
(640, 1536),
]

# Warm up the pipelines. This only need once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 3
prompt, negative_prompt = repeat_prompt(args)
for height, width in image_sizes:
args.height = height
args.width = width
print(f"\nWarm up pipelines for Batch_size={batch_size}, Height={height}, Width={width}")
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)

# Run pipeline on a list of prompts.
prompts = [
"starry night over Golden Gate Bridge by van gogh",
"little cute gremlin sitting on a bed, cinematic",
]
args.num_warmup_runs = 0
for example_prompt in prompts:
args.prompt = [example_prompt]
prompt, negative_prompt = repeat_prompt(args)

for height, width in image_sizes:
args.height = height
args.width = width
print(f"\nBatch_size={batch_size}, Height={height}, Width={width}, Prompt={example_prompt}")
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)

base.teardown()
refiner.teardown()


if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
run_demo()

args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
if no_prompt:
run_dynamic_shape_demo(args)
else:
run_demo(args)
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def parse_arguments(is_xl: bool, description: str):
help="Root Directory to store torch or ONNX models, built engines and output images etc.",
)

parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.")
parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")

parser.add_argument(
"--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,21 @@ def infer_shapes(self):


class PipelineInfo:
def __init__(self, version: str, is_inpaint: bool = False, is_refiner: bool = False, use_vae=False):
def __init__(
self,
version: str,
is_inpaint: bool = False,
is_refiner: bool = False,
use_vae=False,
min_image_size=256,
max_image_size=1024,
):
self.version = version
self._is_inpaint = is_inpaint
self._is_refiner = is_refiner
self._use_vae = use_vae

self._min_image_size = min_image_size
self._max_image_size = max_image_size
if is_refiner:
assert self.is_xl()

Expand Down Expand Up @@ -187,6 +196,12 @@ def unet_embedding_dim(self):
else:
raise ValueError(f"Invalid version {self.version}")

def min_image_size(self):
return self._min_image_size

def max_image_size(self):
return self._max_image_size


class BaseModel:
def __init__(
Expand All @@ -209,8 +224,8 @@ def __init__(

self.min_batch = 1
self.max_batch = max_batch_size
self.min_image_shape = 256 # min image resolution: 256x256
self.max_image_shape = 1024 # max image resolution: 1024x1024
self.min_image_shape = pipeline_info.min_image_size()
self.max_image_shape = pipeline_info.max_image_size()
self.min_latent_shape = self.min_image_shape // 8
self.max_latent_shape = self.max_image_shape // 8

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,9 @@ def is_backend_tensorrt(self):
return self.engine_type == EngineType.TRT

def set_denoising_steps(self, denoising_steps: int):
if self.denoising_steps != denoising_steps:
assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs
# Pre-compute latent input scales and linear multistep coefficients
self.scheduler.set_timesteps(denoising_steps)
self.scheduler.configure()
self.denoising_steps = denoising_steps
self.scheduler.set_timesteps(denoising_steps)
self.scheduler.configure()
self.denoising_steps = denoising_steps

def load_resources(self, image_height, image_width, batch_size):
# If engine is built with static input shape, call this only once after engine build.
Expand Down

0 comments on commit 13e83a8

Please sign in to comment.