Skip to content

Commit

Permalink
SDXL demo: Add Option to disable refiner (#18455)
Browse files Browse the repository at this point in the history
Add option to disable refiner and only run base model.
  • Loading branch information
tianleiwu authored Nov 16, 2023
1 parent 16d7f55 commit 119e86e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def load_pipelines(args, batch_size):
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048

# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(args.version, use_vae=False, min_image_size=min_image_size, max_image_size=max_image_size)
base_info = PipelineInfo(
args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size
)

# Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to
# optimize the shape used most frequently. We can let user config it when we develop a UI plugin.
Expand All @@ -74,33 +76,36 @@ def load_pipelines(args, batch_size):
opt_image_width,
)

refiner_info = PipelineInfo(
args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
)
refiner = init_pipeline(
Img2ImgXLPipeline,
refiner_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)
refiner = None
if not args.disable_refiner:
refiner_info = PipelineInfo(
args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
)
refiner = init_pipeline(
Img2ImgXLPipeline,
refiner_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
refiner.backend.activate_engines(shared_device_memory)
if refiner:
refiner.backend.activate_engines(shared_device_memory)

if engine_type == EngineType.ORT_CUDA:
enable_vae_slicing = args.enable_vae_slicing
if batch_size > 4 and not enable_vae_slicing:
print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.")
enable_vae_slicing = True
if enable_vae_slicing:
refiner.backend.enable_vae_slicing()
(refiner or base).backend.enable_vae_slicing()
return base, refiner


Expand All @@ -109,7 +114,8 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False
image_width = args.width
batch_size = len(prompt)
base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)
if refiner:
refiner.load_resources(image_height, image_width, batch_size)

def run_base_and_refiner(warmup=False):
images, time_base = base.run(
Expand All @@ -121,8 +127,10 @@ def run_base_and_refiner(warmup=False):
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
return_type="latent",
return_type="latent" if refiner else "image",
)
if refiner is None:
return images, time_base

# Use same seed in base and refiner.
seed = base.get_current_seed()
Expand Down Expand Up @@ -173,7 +181,8 @@ def run_demo(args):
base, refiner = load_pipelines(args, batch_size)
run_pipelines(args, base, refiner, prompt, negative_prompt)
base.teardown()
refiner.teardown()
if refiner:
refiner.teardown()


def run_dynamic_shape_demo(args):
Expand Down Expand Up @@ -223,15 +232,17 @@ def run_dynamic_shape_demo(args):
args.denoising_steps = steps
args.seed = seed
base.set_scheduler(scheduler)
refiner.set_scheduler(scheduler)
if refiner:
refiner.set_scheduler(scheduler)
print(
f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)

base.teardown()
refiner.teardown()
if refiner:
refiner.teardown()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def parse_arguments(is_xl: bool, description: str):
parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.")

parser.add_argument(
"--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline."
)

group = parser.add_argument_group("Options for ORT_CUDA engine only")
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")

Expand Down

0 comments on commit 119e86e

Please sign in to comment.