Skip to content

Commit

Permalink
Add SDXL-Turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Nov 29, 2023
1 parent 14a3434 commit 1f4ee3e
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,17 @@ python3 demo_txt2img_xl.py "Self-portrait oil painting, a beautiful cyborg with
python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"
```

#### Generate an image with SDXL Turbo model guided by a text prompt
```
python3 demo_txt2img_xl.py --version xl-turbo --height 512 --width 512 --denoising-steps 1 --scheduler UniPC "little cute gremlin sitting on a bed, cinematic"
```

#### Generate an image with a text prompt using a control net
```
python3 demo_txt2img.py "Stormtrooper's lecture in beautiful lecture hall" --controlnet-type depth --controlnet-scale 1.0
python3 demo_txt2img_xl.py "young Mona Lisa" --controlnet-type canny --controlnet-scale 0.5 --scheduler UniPC --disable-refiner
```

## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum

If you can run the above demo with Docker, you can keep using the Docker container, skip the following setup, and jump ahead to [Export ONNX pipeline](#export-onnx-pipeline).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ def load_pipelines(args, batch_size):
# For TensorRT, the performance of an engine built with dynamic shapes is very sensitive to the range of image sizes.
# Here, we narrow the range of image sizes for TensorRT, trading flexibility for performance.
# This range covers the most frequent shapes: landscape (832x1216), portrait (1216x832) and square (1024x1024).
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
if args.version == "xl-turbo":
min_image_size = 512
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
else:
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048

# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_arguments(is_xl: bool, parser):
parser.add_argument(
"--version",
type=str,
default=supported_versions[-1] if is_xl else "1.5",
default="xl-1.0" if is_xl else "1.5",
choices=supported_versions,
help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
)
Expand Down Expand Up @@ -244,6 +244,20 @@ def parse_arguments(is_xl: bool, parser):
args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17

if is_xl:
if args.version == "xl-turbo":
if args.guidance > 1.0:
print("[I] Use --guidance=1.0 for sdxl-turbo.")
args.guidance = 1.0
if args.lcm:
print("[I] sdxl-turbo cannot use with LCM.")
args.lcm = False
if args.denoising_steps > 8:
print("[I] Use --denoising_steps=4 (no more than 8) for sdxl-turbo.")
args.denoising_steps = 4
if not args.disable_refiner:
print("[I] sdxl-turbo cannot use with SDXL refiner.")
args.disable_refiner = True

if args.lcm and args.scheduler != "LCM":
print("[I] Use --scheduler=LCM for base since LCM is used.")
args.scheduler = "LCM"
Expand Down Expand Up @@ -628,12 +642,12 @@ def process_controlnet_arguments(args):
assert isinstance(args.controlnet_type, list)
assert isinstance(args.controlnet_scale, list)
assert isinstance(args.controlnet_image, list)
if args.version not in ["1.5", "xl-1.0"]:
raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5 or XL.")
if args.version not in ["1.5", "xl-1.0", "xl-turbo"]:
raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")

is_xl = args.version == "xl-1.0"
is_xl = "xl" in args.version
if is_xl and len(args.controlnet_type) > 1:
raise ValueError("This demo only support one ControlNet for Stable Diffusion XL.")
raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")

if len(args.controlnet_image) != 0 and len(args.controlnet_image) != len(args.controlnet_scale):
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,23 @@ def is_inpaint(self) -> bool:
def is_xl(self) -> bool:
return "xl" in self.version

def is_xl_turbo(self) -> bool:
return self.version == "xl-turbo"

def is_xl_base(self) -> bool:
return self.is_xl() and not self._is_refiner
return self.version == "xl-1.0" and not self._is_refiner

def is_xl_base_or_turbo(self) -> bool:
return self.is_xl_base() or self.is_xl_turbo()

def is_xl_refiner(self) -> bool:
return self.is_xl() and self._is_refiner
return self.version == "xl-1.0" and self._is_refiner

def use_safetensors(self) -> bool:
return self.is_xl()

def stages(self) -> List[str]:
if self.is_xl_base():
if self.is_xl_base_or_turbo():
return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else [])

if self.is_xl_refiner():
Expand All @@ -153,7 +159,7 @@ def custom_unet(self) -> Optional[str]:

@staticmethod
def supported_versions(is_xl: bool):
return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"]
return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"]

def name(self) -> str:
if self.version == "1.4":
Expand Down Expand Up @@ -185,6 +191,8 @@ def name(self) -> str:
return "stabilityai/stable-diffusion-xl-refiner-1.0"
else:
return "stabilityai/stable-diffusion-xl-base-1.0"
elif self.version == "xl-turbo":
return "stabilityai/sdxl-turbo"

raise ValueError(f"Incorrect version {self.version}")

Expand All @@ -197,13 +205,13 @@ def clip_embedding_dim(self):
return 768
elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"):
return 1024
elif self.version in ("xl-1.0") and self.is_xl_base():
elif self.is_xl_base_or_turbo():
return 768
else:
raise ValueError(f"Invalid version {self.version}")

def clipwithproj_embedding_dim(self):
if self.version in ("xl-1.0"):
if self.is_xl_base_or_turbo():
return 1280
else:
raise ValueError(f"Invalid version {self.version}")
Expand All @@ -213,9 +221,9 @@ def unet_embedding_dim(self):
return 768
elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"):
return 1024
elif self.version in ("xl-1.0") and self.is_xl_base():
elif self.is_xl_base_or_turbo():
return 2048
elif self.version in ("xl-1.0") and self.is_xl_refiner():
elif self.version == "xl-1.0" and self.is_xl_refiner():
return 1280
else:
raise ValueError(f"Invalid version {self.version}")
Expand All @@ -227,15 +235,15 @@ def max_image_size(self):
return self._max_image_size

def default_image_size(self):
if self.is_xl():
if self.version == "xl-1.0":
return 1024
if self.version in ("2.0", "2.1"):
return 768
return 512

@staticmethod
def supported_controlnet(version="1.5"):
if version == "xl-1.0":
if version in ("xl-1.0", "xl-turbo"):
return {
"canny": "diffusers/controlnet-canny-sdxl-1.0",
"depth": "diffusers/controlnet-depth-sdxl-1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs):
pipeline_info (PipelineInfo):
Version and Type of stable diffusion pipeline.
"""
assert pipeline_info.is_xl_base()
assert pipeline_info.is_xl_base_or_turbo()

super().__init__(pipeline_info, *args, **kwargs)

Expand Down

0 comments on commit 1f4ee3e

Please sign in to comment.