Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ONNX conversion and pipeline for SD3 #8984

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 292 additions & 0 deletions scripts/convert_stable_diffusion_3_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import shutil
from pathlib import Path

import onnx
import torch
from packaging import version
from torch.onnx import export

from diffusers import OnnxRuntimeModel, OnnxStableDiffusion3Pipeline, StableDiffusion3Pipeline


is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
use_external_data_format=False,
):
output_path.parent.mkdir(parents=True, exist_ok=True)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if is_torch_less_than_1_11:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
dtype = torch.float16 if fp16 else torch.float32
if fp16 and torch.cuda.is_available():
device = "cuda"
elif fp16 and not torch.cuda.is_available():
raise ValueError("`float16` model export is only supported on GPUs with CUDA")
else:
device = "cpu"
pipeline = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
output_path = Path(output_path)

# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder.config.hidden_size
text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
text_input.input_ids.to(device=device, dtype=torch.int32),
None,
None,
None,
True,
),
output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=opset,
)
del pipeline.text_encoder

num_tokens = pipeline.text_encoder_2.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder_2.config.hidden_size
text_input = pipeline.tokenizer_2(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
onnx_export(
pipeline.text_encoder_2,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
text_input.input_ids.to(device=device, dtype=torch.int32),
None,
None,
None,
True,
),
output_path=output_path / "text_encoder_2" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=opset,
)
del pipeline.text_encoder_2

text_input = pipeline.tokenizer_3(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer_3.model_max_length,
truncation=True,
return_tensors="pt",
)
onnx_export(
pipeline.text_encoder_3,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
output_path=output_path / "text_encoder_3" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=opset,
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also collate the external tensor files into one file for this model? Like what you are doing with the transformer.

del pipeline.text_encoder_3

# TRANSFORMER
in_channels = pipeline.transformer.config.in_channels
sample_size = pipeline.transformer.config.sample_size
joint_attention_dim = pipeline.transformer.config.joint_attention_dim
pooled_projection_dim = pipeline.transformer.config.pooled_projection_dim
transformer_path = output_path / "transformer" / "model.onnx"
onnx_export(
pipeline.transformer,
model_args=(
torch.randn(2, in_channels, sample_size, sample_size).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, joint_attention_dim).to(device=device, dtype=dtype),
torch.randn(2, pooled_projection_dim).to(device=device, dtype=dtype),
torch.randn(2).to(device=device, dtype=dtype),
),
output_path=transformer_path,
ordered_input_names=["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
"hidden_states": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"encoder_hidden_states": {0: "batch", 1: "sequence", 2: "embed_dims"},
"pooled_projections": {0: "batch", 1: "projection_dim"},
"timestep": {0: "batch"},
},
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
model_path = str(transformer_path.absolute().as_posix())
transformer_dir = os.path.dirname(model_path)
transformer = onnx.load(model_path)
# clean up existing tensor files
shutil.rmtree(transformer_dir)
os.mkdir(transformer_dir)
# collate external tensor files into one
onnx.save_model(
transformer,
model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)
del pipeline.transformer

# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)

# VAE DECODER
vae_decoder = pipeline.vae
vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(
torch.randn(1, vae_latent_channels, sample_size, sample_size).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
del pipeline.vae

onnx_pipeline = OnnxStableDiffusion3Pipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
text_encoder_2=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder_2"),
tokenizer_2=pipeline.tokenizer_2,
text_encoder_3=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder_3"),
tokenizer_3=pipeline.tokenizer_3,
transformer=OnnxRuntimeModel.from_pretrained(output_path / "transformer"),
scheduler=pipeline.scheduler,
)

onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)

del pipeline
del onnx_pipeline
_ = OnnxStableDiffusion3Pipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
print("ONNX pipeline is loadable")


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
)

parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

parser.add_argument(
"--opset",
default=14,
type=int,
help="The version of the ONNX operator set to use.",
)
parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

args = parser.parse_args()

convert_models(args.model_path, args.output_path, args.opset, args.fp16)
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@
"OnnxStableDiffusionPipeline",
"OnnxStableDiffusionUpscalePipeline",
"StableDiffusionOnnxPipeline",
"OnnxStableDiffusion3Pipeline",
]
)

Expand Down Expand Up @@ -793,6 +794,7 @@
OnnxStableDiffusionPipeline,
OnnxStableDiffusionUpscalePipeline,
StableDiffusionOnnxPipeline,
OnnxStableDiffusion3Pipeline,
)

try:
Expand Down
24 changes: 24 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,20 @@
]
)

try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
_import_structure["stable_diffusion_3"].extend(
[
"OnnxStableDiffusion3Pipeline",
]
)

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -610,6 +624,16 @@
StableDiffusionOnnxPipeline,
)

try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
else:
from .stable_diffusion_3 import (
OnnxStableDiffusion3Pipeline,
)

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
Expand Down
21 changes: 21 additions & 0 deletions src/diffusers/pipelines/stable_diffusion_3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
is_flax_available,
is_torch_available,
is_transformers_available,
is_onnx_available,
)


Expand All @@ -27,6 +28,18 @@
_import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"]

try:
if not (is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_onnx_objects # noqa F403

_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
_import_structure["pipeline_onnx_stable_diffusion_3"] = [
"OnnxStableDiffusion3Pipeline",
]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
Expand All @@ -38,6 +51,14 @@
from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline

try:
if not (is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import *
else:
from .pipeline_onnx_stable_diffusion_3 import OnnxStableDiffusion3Pipeline

else:
import sys

Expand Down
Loading